@@ -121,7 +121,9 @@ def get():
                  "important_kwd")
def set():
    req = request.json
    d = {"id": req["chunk_id"], "content_with_weight": req["content_with_weight"]}
    d = {
        "id": req["chunk_id"],
        "content_with_weight": req["content_with_weight"]}
    d["content_ltks"] = huqie.qie(req["content_with_weight"])
    d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"])
    d["important_kwd"] = req["important_kwd"]
@@ -140,10 +142,16 @@ def set():
            return get_data_error_result(retmsg="Document not found!")
        if doc.parser_id == ParserType.QA:
            arr = [t for t in re.split(r"[\n\t]", req["content_with_weight"]) if len(t) > 1]
            if len(arr) != 2: return get_data_error_result(retmsg="Q&A must be separated by TAB/ENTER key.")
            arr = [
                t for t in re.split(
                    r"[\n\t]",
                    req["content_with_weight"]) if len(t) > 1]
            if len(arr) != 2:
                return get_data_error_result(
                    retmsg="Q&A must be separated by TAB/ENTER key.")
            q, a = rmPrefix(arr[0]), rmPrefix(arr[1])
            d = beAdoc(d, arr[0], arr[1], not any([huqie.is_chinese(t) for t in q + a]))
            d = beAdoc(d, arr[0], arr[1], not any(
                [huqie.is_chinese(t) for t in q + a]))
        v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
        v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
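A note on the weighted encode in this hunk: the chunk vector blends the document-title embedding with the content embedding, except for Q&A chunks, which keep only the content vector. A minimal sketch of the arithmetic, with random stand-in vectors in place of embd_mdl.encode():

import numpy as np

# Stand-ins for the two vectors returned by embd_mdl.encode([doc.name, content]).
title_vec = np.random.rand(768)
content_vec = np.random.rand(768)

# 10% title + 90% content keeps a little document-level signal in every chunk
# vector; Q&A chunks skip the title entirely (v[1] above).
chunk_vec = 0.1 * title_vec + 0.9 * content_vec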
@@ -177,7 +185,8 @@ def switch():
def rm():
    req = request.json
    try:
        if not ELASTICSEARCH.deleteByQuery(Q("ids", values=req["chunk_ids"]), search.index_name(current_user.id)):
        if not ELASTICSEARCH.deleteByQuery(
                Q("ids", values=req["chunk_ids"]), search.index_name(current_user.id)):
            return get_data_error_result(retmsg="Index updating failure")
        return get_json_result(data=True)
    except Exception as e:
@@ -100,7 +100,10 @@ def rm():
def list_convsersation():
    dialog_id = request.args["dialog_id"]
    try:
        convs = ConversationService.query(dialog_id=dialog_id, order_by=ConversationService.model.create_time, reverse=True)
        convs = ConversationService.query(
            dialog_id=dialog_id,
            order_by=ConversationService.model.create_time,
            reverse=True)
        convs = [d.to_dict() for d in convs]
        return get_json_result(data=convs)
    except Exception as e:
@@ -111,19 +114,24 @@ def message_fit_in(msg, max_length=4000):
    def count():
        nonlocal msg
        tks_cnts = []
        for m in msg: tks_cnts.append({"role": m["role"], "count": num_tokens_from_string(m["content"])})
        for m in msg:
            tks_cnts.append(
                {"role": m["role"], "count": num_tokens_from_string(m["content"])})
        total = 0
        for m in tks_cnts: total += m["count"]
        for m in tks_cnts:
            total += m["count"]
        return total
    c = count()
    if c < max_length: return c, msg
    if c < max_length:
        return c, msg
    msg_ = [m for m in msg[:-1] if m["role"] == "system"]
    msg_.append(msg[-1])
    msg = msg_
    c = count()
    if c < max_length: return c, msg
    if c < max_length:
        return c, msg
    ll = num_tokens_from_string(msg_[0]["content"])
    l = num_tokens_from_string(msg_[-1]["content"])
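For readers skimming this hunk: message_fit_in() trims a chat history to a token budget. A condensed sketch of the strategy shown above, assuming a count_tokens() helper in place of num_tokens_from_string():

def fit(messages, max_length, count_tokens):
    # Under budget: return the history unchanged.
    total = sum(count_tokens(m["content"]) for m in messages)
    if total < max_length:
        return total, messages
    # Over budget: keep only system messages plus the latest message; the
    # code after this hunk truncates further if that is still too long.
    kept = [m for m in messages[:-1] if m["role"] == "system"]
    kept.append(messages[-1])
    return sum(count_tokens(m["content"]) for m in kept), kept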
@@ -146,8 +154,10 @@ def completion():
    req = request.json
    msg = []
    for m in req["messages"]:
        if m["role"] == "system": continue
        if m["role"] == "assistant" and not msg: continue
        if m["role"] == "system":
            continue
        if m["role"] == "assistant" and not msg:
            continue
        msg.append({"role": m["role"], "content": m["content"]})
    try:
        e, conv = ConversationService.get_by_id(req["conversation_id"])
@@ -160,7 +170,8 @@ def completion():
        del req["conversation_id"]
        del req["messages"]
        ans = chat(dia, msg, **req)
        if not conv.reference: conv.reference = []
        if not conv.reference:
            conv.reference = []
        conv.reference.append(ans["reference"])
        conv.message.append({"role": "assistant", "content": ans["answer"]})
        ConversationService.update_by_id(conv.id, conv.to_dict())
@@ -180,52 +191,67 @@ def chat(dialog, messages, **kwargs):
    chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
    field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
    ## try to use sql if field mapping is good to go
    # try to use sql if field mapping is good to go
    if field_map:
        chat_logger.info("Use SQL to retrieval:{}".format(questions[-1]))
        return use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl)
    prompt_config = dialog.prompt_config
    for p in prompt_config["parameters"]:
        if p["key"] == "knowledge": continue
        if p["key"] not in kwargs and not p["optional"]: raise KeyError("Miss parameter: " + p["key"])
        if p["key"] == "knowledge":
            continue
        if p["key"] not in kwargs and not p["optional"]:
            raise KeyError("Miss parameter: " + p["key"])
        if p["key"] not in kwargs:
            prompt_config["system"] = prompt_config["system"].replace("{%s}" % p["key"], " ")
            prompt_config["system"] = prompt_config["system"].replace(
                "{%s}" % p["key"], " ")
    for _ in range(len(questions)//2):
    for _ in range(len(questions) // 2):
        questions.append(questions[-1])
    if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]:
        kbinfos = {"total":0, "chunks":[],"doc_aggs":[]}
        kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}
    else:
        kbinfos = retrievaler.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
                                        dialog.similarity_threshold,
                                        dialog.vector_similarity_weight, top=1024, aggs=False)
                                        dialog.similarity_threshold,
                                        dialog.vector_similarity_weight, top=1024, aggs=False)
    knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
    chat_logger.info("{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
    chat_logger.info(
        "{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
    if not knowledges and prompt_config.get("empty_response"):
        return {"answer": prompt_config["empty_response"], "reference": kbinfos}
        return {
            "answer": prompt_config["empty_response"], "reference": kbinfos}
    kwargs["knowledge"] = "\n".join(knowledges)
    gen_conf = dialog.llm_setting
    msg = [{"role": m["role"], "content": m["content"]} for m in messages if m["role"] != "system"]
    msg = [{"role": m["role"], "content": m["content"]}
           for m in messages if m["role"] != "system"]
    used_token_count, msg = message_fit_in(msg, int(llm.max_tokens * 0.97))
    if "max_tokens" in gen_conf:
        gen_conf["max_tokens"] = min(gen_conf["max_tokens"], llm.max_tokens - used_token_count)
    answer = chat_mdl.chat(prompt_config["system"].format(**kwargs), msg, gen_conf)
    chat_logger.info("User: {}|Assistant: {}".format(msg[-1]["content"], answer))
        gen_conf["max_tokens"] = min(
            gen_conf["max_tokens"],
            llm.max_tokens - used_token_count)
    answer = chat_mdl.chat(
        prompt_config["system"].format(
            **kwargs), msg, gen_conf)
    chat_logger.info("User: {}|Assistant: {}".format(
        msg[-1]["content"], answer))
    if knowledges:
        answer, idx = retrievaler.insert_citations(answer,
                                                   [ck["content_ltks"] for ck in kbinfos["chunks"]],
                                                   [ck["vector"] for ck in kbinfos["chunks"]],
                                                   embd_mdl,
                                                   tkweight=1 - dialog.vector_similarity_weight,
                                                   vtweight=dialog.vector_similarity_weight)
                                                   [ck["content_ltks"]
                                                       for ck in kbinfos["chunks"]],
                                                   [ck["vector"]
                                                       for ck in kbinfos["chunks"]],
                                                   embd_mdl,
                                                   tkweight=1 - dialog.vector_similarity_weight,
                                                   vtweight=dialog.vector_similarity_weight)
        idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
        kbinfos["doc_aggs"] = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
        kbinfos["doc_aggs"] = [
            d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
    for c in kbinfos["chunks"]:
        if c.get("vector"): del c["vector"]
        if c.get("vector"):
            del c["vector"]
    return {"answer": answer, "reference": kbinfos}
@@ -245,9 +271,11 @@ def use_sql(question, field_map, tenant_id, chat_mdl):
        question
    )
    tried_times = 0
    def get_table():
        nonlocal sys_prompt, user_promt, question, tried_times
        sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_promt}], {"temperature": 0.06})
        sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_promt}], {
            "temperature": 0.06})
        print(user_promt, sql)
        chat_logger.info(f"“{question}”==>{user_promt} get SQL: {sql}")
        sql = re.sub(r"[\r\n]+", " ", sql.lower())
@@ -262,8 +290,10 @@ def use_sql(question, field_map, tenant_id, chat_mdl):
            else:
                flds = []
                for k in field_map.keys():
                    if k in forbidden_select_fields4resume:continue
                    if len(flds) > 11:break
                    if k in forbidden_select_fields4resume:
                        continue
                    if len(flds) > 11:
                        break
                    flds.append(k)
                sql = "select doc_id,docnm_kwd," + ",".join(flds) + sql[8:]
@@ -284,13 +314,13 @@ def use_sql(question, field_map, tenant_id, chat_mdl):
        问题如下:
        {}
        你上一次给出的错误SQL如下:
        {}
        后台报错如下:
        {}
        请纠正SQL中的错误再写一遍,且只要SQL,不要有其他说明及文字。
        """.format(
            index_name(tenant_id),
@@ -302,16 +332,24 @@ def use_sql(question, field_map, tenant_id, chat_mdl):
    chat_logger.info("GET table: {}".format(tbl))
    print(tbl)
    if tbl.get("error") or len(tbl["rows"]) == 0: return None, None
    if tbl.get("error") or len(tbl["rows"]) == 0:
        return None, None
    docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "doc_id"])
    docnm_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "docnm_kwd"])
    clmn_idx = [ii for ii in range(len(tbl["columns"])) if ii not in (docid_idx | docnm_idx)]
    docid_idx = set([ii for ii, c in enumerate(
        tbl["columns"]) if c["name"] == "doc_id"])
    docnm_idx = set([ii for ii, c in enumerate(
        tbl["columns"]) if c["name"] == "docnm_kwd"])
    clmn_idx = [ii for ii in range(
        len(tbl["columns"])) if ii not in (docid_idx | docnm_idx)]
    # compose markdown table
    clmns = "|"+"|".join([re.sub(r"(/.*|（[^（）]+）)", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in clmn_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
    line = "|"+"|".join(["------" for _ in range(len(clmn_idx))]) + ("|------|" if docid_idx and docid_idx else "")
    rows = ["|"+"|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") + "|" for r in tbl["rows"]]
    clmns = "|" + "|".join([re.sub(r"(/.*|（[^（）]+）)", "", field_map.get(tbl["columns"][i]["name"],
                                                                      tbl["columns"][i]["name"])) for i in clmn_idx]) + ("|Source|" if docid_idx and docnm_idx else "|")
    line = "|" + "|".join(["------" for _ in range(len(clmn_idx))]) + \
        ("|------|" if docid_idx and docnm_idx else "")
    rows = ["|" +
            "|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") +
            "|" for r in tbl["rows"]]
    if not docid_idx or not docnm_idx:
        chat_logger.warning("SQL missing field: " + sql)
        return "\n".join([clmns, line, "\n".join(rows)]), []
@@ -328,5 +366,5 @@ def use_sql(question, field_map, tenant_id, chat_mdl):
    return {
        "answer": "\n".join([clmns, line, rows]),
        "reference": {"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[docnm_idx]} for r in tbl["rows"]],
                      "doc_aggs":[{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in doc_aggs.items()]}
                      "doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in doc_aggs.items()]}
    }
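The clmns/line/rows triple above assembles a markdown table from the SQL result. A self-contained sketch with hypothetical data in place of tbl["columns"]/tbl["rows"]:

columns = ["name", "age"]
rows = [["Alice", 30], ["Bob", None]]

header = "|" + "|".join(columns) + "|"
separator = "|" + "|".join(["------"] * len(columns)) + "|"
# None cells are blanked, mirroring the .replace("None", " ") above.
body = ["|" + "|".join(str(v) for v in r).replace("None", " ") + "|" for r in rows]
print("\n".join([header, separator] + body))
# |name|age|
# |------|------|
# |Alice|30|
# |Bob| |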
@@ -55,7 +55,8 @@ def set_dialog():
    }
    prompt_config = req.get("prompt_config", default_prompt)
    if not prompt_config["system"]: prompt_config["system"] = default_prompt["system"]
    if not prompt_config["system"]:
        prompt_config["system"] = default_prompt["system"]
    # if len(prompt_config["parameters"]) < 1:
    #     prompt_config["parameters"] = default_prompt["parameters"]
    # for p in prompt_config["parameters"]:
@@ -63,16 +64,21 @@ def set_dialog():
    #     else: prompt_config["parameters"].append(default_prompt["parameters"][0])
    for p in prompt_config["parameters"]:
        if p["optional"]: continue
        if p["optional"]:
            continue
        if prompt_config["system"].find("{%s}" % p["key"]) < 0:
            return get_data_error_result(retmsg="Parameter '{}' is not used".format(p["key"]))
            return get_data_error_result(
                retmsg="Parameter '{}' is not used".format(p["key"]))
    try:
        e, tenant = TenantService.get_by_id(current_user.id)
        if not e: return get_data_error_result(retmsg="Tenant not found!")
        if not e:
            return get_data_error_result(retmsg="Tenant not found!")
        llm_id = req.get("llm_id", tenant.llm_id)
        if not dialog_id:
            if not req.get("kb_ids"): return get_data_error_result(retmsg="Fail! Please select knowledgebase!")
            if not req.get("kb_ids"):
                return get_data_error_result(
                    retmsg="Fail! Please select knowledgebase!")
            dia = {
                "id": get_uuid(),
                "tenant_id": current_user.id,
@@ -86,17 +92,21 @@ def set_dialog():
                "similarity_threshold": similarity_threshold,
                "vector_similarity_weight": vector_similarity_weight
            }
            if not DialogService.save(**dia): return get_data_error_result(retmsg="Fail to new a dialog!")
            if not DialogService.save(**dia):
                return get_data_error_result(retmsg="Fail to new a dialog!")
            e, dia = DialogService.get_by_id(dia["id"])
            if not e: return get_data_error_result(retmsg="Fail to new a dialog!")
            if not e:
                return get_data_error_result(retmsg="Fail to new a dialog!")
            return get_json_result(data=dia.to_json())
        else:
            del req["dialog_id"]
            if "kb_names" in req: del req["kb_names"]
            if "kb_names" in req:
                del req["kb_names"]
            if not DialogService.update_by_id(dialog_id, req):
                return get_data_error_result(retmsg="Dialog not found!")
            e, dia = DialogService.get_by_id(dialog_id)
            if not e: return get_data_error_result(retmsg="Fail to update a dialog!")
            if not e:
                return get_data_error_result(retmsg="Fail to update a dialog!")
            dia = dia.to_dict()
            dia["kb_ids"], dia["kb_names"] = get_kb_names(dia["kb_ids"])
            return get_json_result(data=dia)
@@ -110,7 +120,8 @@ def get():
    dialog_id = request.args["dialog_id"]
    try:
        e, dia = DialogService.get_by_id(dialog_id)
        if not e: return get_data_error_result(retmsg="Dialog not found!")
        if not e:
            return get_data_error_result(retmsg="Dialog not found!")
        dia = dia.to_dict()
        dia["kb_ids"], dia["kb_names"] = get_kb_names(dia["kb_ids"])
        return get_json_result(data=dia)
@@ -122,7 +133,8 @@ def get_kb_names(kb_ids):
    ids, nms = [], []
    for kid in kb_ids:
        e, kb = KnowledgebaseService.get_by_id(kid)
        if not e or kb.status != StatusEnum.VALID.value: continue
        if not e or kb.status != StatusEnum.VALID.value:
            continue
        ids.append(kid)
        nms.append(kb.name)
    return ids, nms
@@ -132,7 +144,11 @@ def get_kb_names(kb_ids):
@login_required
def list():
    try:
        diags = DialogService.query(tenant_id=current_user.id, status=StatusEnum.VALID.value, reverse=True, order_by=DialogService.model.create_time)
        diags = DialogService.query(
            tenant_id=current_user.id,
            status=StatusEnum.VALID.value,
            reverse=True,
            order_by=DialogService.model.create_time)
        diags = [d.to_dict() for d in diags]
        for d in diags:
            d["kb_ids"], d["kb_names"] = get_kb_names(d["kb_ids"])
@@ -147,7 +163,8 @@ def list():
def rm():
    req = request.json
    try:
        DialogService.update_many_by_id([{"id": id, "status": StatusEnum.INVALID.value} for id in req["dialog_ids"]])
        DialogService.update_many_by_id(
            [{"id": id, "status": StatusEnum.INVALID.value} for id in req["dialog_ids"]])
        return get_json_result(data=True)
    except Exception as e:
        return server_error_response(e)
@@ -57,6 +57,9 @@ def upload():
        if not e:
            return get_data_error_result(
                retmsg="Can't find this knowledgebase!")
        if DocumentService.get_doc_count(kb.tenant_id) >= 128:
            return get_data_error_result(
                retmsg="Exceed the maximum file number of a free user!")
        filename = duplicate_name(
            DocumentService.query,
@@ -215,9 +218,11 @@ def rm():
        tenant_id = DocumentService.get_tenant_id(req["doc_id"])
        if not tenant_id:
            return get_data_error_result(retmsg="Tenant not found!")
        ELASTICSEARCH.deleteByQuery(Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id))
        ELASTICSEARCH.deleteByQuery(
            Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id))
        DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num * -1, doc.chunk_num * -1, 0)
        DocumentService.increment_chunk_num(
            doc.id, doc.kb_id, doc.token_num * -1, doc.chunk_num * -1, 0)
        if not DocumentService.delete(doc):
            return get_data_error_result(
                retmsg="Database error (Document removal)!")
@@ -245,7 +250,8 @@ def run():
            tenant_id = DocumentService.get_tenant_id(id)
            if not tenant_id:
                return get_data_error_result(retmsg="Tenant not found!")
            ELASTICSEARCH.deleteByQuery(Q("match", doc_id=id), idxnm=search.index_name(tenant_id))
            ELASTICSEARCH.deleteByQuery(
                Q("match", doc_id=id), idxnm=search.index_name(tenant_id))
        return get_json_result(data=True)
    except Exception as e:
@@ -261,7 +267,8 @@ def rename():
        e, doc = DocumentService.get_by_id(req["doc_id"])
        if not e:
            return get_data_error_result(retmsg="Document not found!")
        if pathlib.Path(req["name"].lower()).suffix != pathlib.Path(doc.name.lower()).suffix:
        if pathlib.Path(req["name"].lower()).suffix != pathlib.Path(
                doc.name.lower()).suffix:
            return get_json_result(
                data=False,
                retmsg="The extension of file can't be changed",
@@ -294,7 +301,10 @@ def get(doc_id):
        if doc.type == FileType.VISUAL.value:
            response.headers.set('Content-Type', 'image/%s' % ext.group(1))
        else:
            response.headers.set('Content-Type', 'application/%s' % ext.group(1))
            response.headers.set(
                'Content-Type',
                'application/%s' %
                ext.group(1))
        return response
    except Exception as e:
        return server_error_response(e)
@@ -313,9 +323,11 @@ def change_parser():
        if "parser_config" in req:
            if req["parser_config"] == doc.parser_config:
                return get_json_result(data=True)
        else: return get_json_result(data=True)
        else:
            return get_json_result(data=True)
        if doc.type == FileType.VISUAL or re.search(r"\.(ppt|pptx|pages)$", doc.name):
        if doc.type == FileType.VISUAL or re.search(
                r"\.(ppt|pptx|pages)$", doc.name):
            return get_data_error_result(retmsg="Not supported yet!")
        e = DocumentService.update_by_id(doc.id,
@@ -332,7 +344,8 @@ def change_parser():
            tenant_id = DocumentService.get_tenant_id(req["doc_id"])
            if not tenant_id:
                return get_data_error_result(retmsg="Tenant not found!")
            ELASTICSEARCH.deleteByQuery(Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id))
            ELASTICSEARCH.deleteByQuery(
                Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id))
        return get_json_result(data=True)
    except Exception as e:
@@ -33,15 +33,21 @@ from api.utils.api_utils import get_json_result
def create():
    req = request.json
    req["name"] = req["name"].strip()
    req["name"] = duplicate_name(KnowledgebaseService.query, name=req["name"], tenant_id=current_user.id, status=StatusEnum.VALID.value)
    req["name"] = duplicate_name(
        KnowledgebaseService.query,
        name=req["name"],
        tenant_id=current_user.id,
        status=StatusEnum.VALID.value)
    try:
        req["id"] = get_uuid()
        req["tenant_id"] = current_user.id
        req["created_by"] = current_user.id
        e, t = TenantService.get_by_id(current_user.id)
        if not e: return get_data_error_result(retmsg="Tenant not found.")
        if not e:
            return get_data_error_result(retmsg="Tenant not found.")
        req["embd_id"] = t.embd_id
        if not KnowledgebaseService.save(**req): return get_data_error_result()
        if not KnowledgebaseService.save(**req):
            return get_data_error_result()
        return get_json_result(data={"kb_id": req["id"]})
    except Exception as e:
        return server_error_response(e)
@@ -54,21 +60,29 @@ def update():
    req = request.json
    req["name"] = req["name"].strip()
    try:
        if not KnowledgebaseService.query(created_by=current_user.id, id=req["kb_id"]):
            return get_json_result(data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.', retcode=RetCode.OPERATING_ERROR)
        if not KnowledgebaseService.query(
                created_by=current_user.id, id=req["kb_id"]):
            return get_json_result(
                data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.', retcode=RetCode.OPERATING_ERROR)
        e, kb = KnowledgebaseService.get_by_id(req["kb_id"])
        if not e: return get_data_error_result(retmsg="Can't find this knowledgebase!")
        if not e:
            return get_data_error_result(
                retmsg="Can't find this knowledgebase!")
        if req["name"].lower() != kb.name.lower() \
                and len(KnowledgebaseService.query(name=req["name"], tenant_id=current_user.id, status=StatusEnum.VALID.value))>1:
            return get_data_error_result(retmsg="Duplicated knowledgebase name.")
                and len(KnowledgebaseService.query(name=req["name"], tenant_id=current_user.id, status=StatusEnum.VALID.value)) > 1:
            return get_data_error_result(
                retmsg="Duplicated knowledgebase name.")
        del req["kb_id"]
        if not KnowledgebaseService.update_by_id(kb.id, req): return get_data_error_result()
        if not KnowledgebaseService.update_by_id(kb.id, req):
            return get_data_error_result()
        e, kb = KnowledgebaseService.get_by_id(kb.id)
        if not e: return get_data_error_result(retmsg="Database error (Knowledgebase rename)!")
        if not e:
            return get_data_error_result(
                retmsg="Database error (Knowledgebase rename)!")
        return get_json_result(data=kb.to_json())
    except Exception as e:
@@ -81,7 +95,9 @@ def detail():
    kb_id = request.args["kb_id"]
    try:
        kb = KnowledgebaseService.get_detail(kb_id)
        if not kb: return get_data_error_result(retmsg="Can't find this knowledgebase!")
        if not kb:
            return get_data_error_result(
                retmsg="Can't find this knowledgebase!")
        return get_json_result(data=kb)
    except Exception as e:
        return server_error_response(e)
@@ -96,7 +112,8 @@ def list():
    desc = request.args.get("desc", True)
    try:
        tenants = TenantService.get_joined_tenants_by_user_id(current_user.id)
        kbs = KnowledgebaseService.get_by_tenant_ids([m["tenant_id"] for m in tenants], current_user.id, page_number, items_per_page, orderby, desc)
        kbs = KnowledgebaseService.get_by_tenant_ids(
            [m["tenant_id"] for m in tenants], current_user.id, page_number, items_per_page, orderby, desc)
        return get_json_result(data=kbs)
    except Exception as e:
        return server_error_response(e)
@@ -108,10 +125,15 @@ def list():
def rm():
    req = request.json
    try:
        if not KnowledgebaseService.query(created_by=current_user.id, id=req["kb_id"]):
            return get_json_result(data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.', retcode=RetCode.OPERATING_ERROR)
        if not KnowledgebaseService.update_by_id(req["kb_id"], {"status": StatusEnum.INVALID.value}): return get_data_error_result(retmsg="Database error (Knowledgebase removal)!")
        if not KnowledgebaseService.query(
                created_by=current_user.id, id=req["kb_id"]):
            return get_json_result(
                data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.', retcode=RetCode.OPERATING_ERROR)
        if not KnowledgebaseService.update_by_id(
                req["kb_id"], {"status": StatusEnum.INVALID.value}):
            return get_data_error_result(
                retmsg="Database error (Knowledgebase removal)!")
        return get_json_result(data=True)
    except Exception as e:
        return server_error_response(e)
@@ -48,30 +48,42 @@ def set_api_key():
                req["api_key"], llm.llm_name)
            try:
                arr, tc = mdl.encode(["Test if the api key is available"])
                if len(arr[0]) == 0 or tc ==0: raise Exception("Fail")
                if len(arr[0]) == 0 or tc == 0:
                    raise Exception("Fail")
            except Exception as e:
                msg += f"\nFail to access embedding model({llm.llm_name}) using this api key."
        elif not chat_passed and llm.model_type == LLMType.CHAT.value:
            mdl = ChatModel[factory](
                req["api_key"], llm.llm_name)
            try:
                m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9})
                if not tc: raise Exception(m)
                m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {
                    "temperature": 0.9})
                if not tc:
                    raise Exception(m)
                chat_passed = True
            except Exception as e:
                msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(e)
                msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(
                    e)
    if msg: return get_data_error_result(retmsg=msg)
    if msg:
        return get_data_error_result(retmsg=msg)
    llm = {
        "api_key": req["api_key"]
    }
    for n in ["model_type", "llm_name"]:
        if n in req: llm[n] = req[n]
        if n in req:
            llm[n] = req[n]
    if not TenantLLMService.filter_update([TenantLLM.tenant_id==current_user.id, TenantLLM.llm_factory==factory], llm):
    if not TenantLLMService.filter_update(
            [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == factory], llm):
        for llm in LLMService.query(fid=factory):
            TenantLLMService.save(tenant_id=current_user.id, llm_factory=factory, llm_name=llm.llm_name, model_type=llm.model_type, api_key=req["api_key"])
            TenantLLMService.save(
                tenant_id=current_user.id,
                llm_factory=factory,
                llm_name=llm.llm_name,
                model_type=llm.model_type,
                api_key=req["api_key"])
    return get_json_result(data=True)
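The pattern in this hunk validates an API key by issuing the cheapest possible call per model type and treating an empty result as failure. A condensed sketch using the encode()/chat() signatures that appear above (the model objects themselves are assumed):

def api_key_works(embed_mdl, chat_mdl):
    try:
        vecs, token_count = embed_mdl.encode(["ping"])
        if len(vecs[0]) == 0 or token_count == 0:
            return False
        answer, token_count = chat_mdl.chat(
            None, [{"role": "user", "content": "Hello!"}], {"temperature": 0.9})
        # A zero token count means the provider rejected the call; the
        # answer then carries the error text, as raised above.
        return bool(token_count)
    except Exception:
        return False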
@@ -105,17 +117,19 @@ def list():
        objs = TenantLLMService.query(tenant_id=current_user.id)
        facts = set([o.to_dict()["llm_factory"] for o in objs if o.api_key])
        llms = LLMService.get_all()
        llms = [m.to_dict() for m in llms if m.status == StatusEnum.VALID.value]
        llms = [m.to_dict()
                for m in llms if m.status == StatusEnum.VALID.value]
        for m in llms:
            m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding"
        res = {}
        for m in llms:
            if model_type and m["model_type"] != model_type: continue
            if m["fid"] not in res: res[m["fid"]] = []
            if model_type and m["model_type"] != model_type:
                continue
            if m["fid"] not in res:
                res[m["fid"]] = []
            res[m["fid"]].append(m)
        return get_json_result(data=res)
    except Exception as e:
        return server_error_response(e)
@@ -40,13 +40,16 @@ def login():
    email = request.json.get('email', "")
    users = UserService.query(email=email)
    if not users: return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg=f'This Email is not registered!')
    if not users:
        return get_json_result(
            data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg=f'This Email is not registered!')
    password = request.json.get('password')
    try:
        password = decrypt(password)
    except:
        return get_json_result(data=False, retcode=RetCode.SERVER_ERROR, retmsg='Fail to crypt password')
    except BaseException:
        return get_json_result(
            data=False, retcode=RetCode.SERVER_ERROR, retmsg='Fail to crypt password')
    user = UserService.query_user(email, password)
    if user:
@@ -57,7 +60,8 @@ def login():
        msg = "Welcome back!"
        return cors_reponse(data=response_data, auth=user.get_id(), retmsg=msg)
    else:
        return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='Email and Password do not match!')
        return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR,
                               retmsg='Email and Password do not match!')
@manager.route('/github_callback', methods=['GET'])
@@ -65,7 +69,7 @@ def github_callback():
    import requests
    res = requests.post(GITHUB_OAUTH.get("url"), data={
        "client_id": GITHUB_OAUTH.get("client_id"),
        "client_secret": GITHUB_OAUTH.get("secret_key"),
        "client_secret": GITHUB_OAUTH.get("secret_key"),
        "code": request.args.get('code')
    }, headers={"Accept": "application/json"})
    res = res.json()
@@ -96,15 +100,17 @@ def github_callback():
                "last_login_time": get_format_time(),
                "is_superuser": False,
            })
            if not users: raise Exception('Register user failure.')
            if len(users) > 1: raise Exception('Same E-mail exist!')
            if not users:
                raise Exception('Register user failure.')
            if len(users) > 1:
                raise Exception('Same E-mail exist!')
            user = users[0]
            login_user(user)
            return redirect("/?auth=%s"%user.get_id())
            return redirect("/?auth=%s" % user.get_id())
        except Exception as e:
            rollback_user_registration(user_id)
            stat_logger.exception(e)
            return redirect("/?error=%s"%str(e))
            return redirect("/?error=%s" % str(e))
    user = users[0]
    user.access_token = get_uuid()
    login_user(user)
@@ -114,11 +120,18 @@ def github_callback():
def user_info_from_github(access_token):
    import requests
    headers = {"Accept": "application/json", 'Authorization': f"token {access_token}"}
    res = requests.get(f"https://api.github.com/user?access_token={access_token}", headers=headers)
    headers = {"Accept": "application/json",
               'Authorization': f"token {access_token}"}
    res = requests.get(
        f"https://api.github.com/user?access_token={access_token}",
        headers=headers)
    user_info = res.json()
    email_info = requests.get(f"https://api.github.com/user/emails?access_token={access_token}", headers=headers).json()
    user_info["email"] = next((email for email in email_info if email['primary'] == True), None)["email"]
    email_info = requests.get(
        f"https://api.github.com/user/emails?access_token={access_token}",
        headers=headers).json()
    user_info["email"] = next(
        (email for email in email_info if email['primary'] == True),
        None)["email"]
    return user_info
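Two observations on user_info_from_github(): the token is sent both as an ?access_token= query parameter and as an Authorization header, and since GitHub deprecated query-parameter authentication, the header is what actually authenticates. Also, the primary-email lookup can raise TypeError when no email is flagged primary, because next(..., None)["email"] subscripts the None default. A small illustration with a hypothetical payload:

# Hypothetical response shape of GET /user/emails.
email_info = [
    {"email": "work@example.com", "primary": False, "verified": True},
    {"email": "me@example.com", "primary": True, "verified": True},
]
primary = next((e for e in email_info if e["primary"]), None)
# Guarding the None case avoids the TypeError noted above.
print(primary["email"] if primary else None)  # me@example.com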
@@ -138,13 +151,18 @@ def setting_user():
    request_data = request.json
    if request_data.get("password"):
        new_password = request_data.get("new_password")
        if not check_password_hash(current_user.password, decrypt(request_data["password"])):
            return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='Password error!')
        if not check_password_hash(
                current_user.password, decrypt(request_data["password"])):
            return get_json_result(
                data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='Password error!')
        if new_password: update_dict["password"] = generate_password_hash(decrypt(new_password))
        if new_password:
            update_dict["password"] = generate_password_hash(
                decrypt(new_password))
    for k in request_data.keys():
        if k in ["password", "new_password"]:continue
        if k in ["password", "new_password"]:
            continue
        update_dict[k] = request_data[k]
    try:
@@ -152,7 +170,8 @@ def setting_user():
        return get_json_result(data=True)
    except Exception as e:
        stat_logger.exception(e)
        return get_json_result(data=False, retmsg='Update failure!', retcode=RetCode.EXCEPTION_ERROR)
        return get_json_result(
            data=False, retmsg='Update failure!', retcode=RetCode.EXCEPTION_ERROR)
@manager.route("/info", methods=["GET"])
@@ -173,11 +192,11 @@ def rollback_user_registration(user_id):
    except Exception as e:
        pass
    try:
        TenantLLM.delete().where(TenantLLM.tenant_id==user_id).excute()
        TenantLLM.delete().where(TenantLLM.tenant_id == user_id).execute()
    except Exception as e:
        pass
def user_register(user_id, user):
    user["id"] = user_id
    tenant = {
@@ -197,9 +216,14 @@ def user_register(user_id, user):
    }
    tenant_llm = []
    for llm in LLMService.query(fid=LLM_FACTORY):
        tenant_llm.append({"tenant_id": user_id, "llm_factory": LLM_FACTORY, "llm_name": llm.llm_name, "model_type":llm.model_type, "api_key": API_KEY})
    if not UserService.save(**user):return
        tenant_llm.append({"tenant_id": user_id,
                           "llm_factory": LLM_FACTORY,
                           "llm_name": llm.llm_name,
                           "model_type": llm.model_type,
                           "api_key": API_KEY})
    if not UserService.save(**user):
        return
    TenantService.insert(**tenant)
    UserTenantService.insert(**usr_tenant)
    TenantLLMService.insert_many(tenant_llm)
@@ -211,7 +235,8 @@ def user_register(user_id, user):
def user_add():
    req = request.json
    if UserService.query(email=req["email"]):
        return get_json_result(data=False, retmsg=f'Email: {req["email"]} has already registered!', retcode=RetCode.OPERATING_ERROR)
        return get_json_result(
            data=False, retmsg=f'Email: {req["email"]} has already registered!', retcode=RetCode.OPERATING_ERROR)
    if not re.match(r"^[\w\._-]+@([\w_-]+\.)+[\w-]{2,4}$", req["email"]):
        return get_json_result(data=False, retmsg=f'Invalid e-mail: {req["email"]}!',
                               retcode=RetCode.OPERATING_ERROR)
@@ -229,16 +254,19 @@ def user_add():
    user_id = get_uuid()
    try:
        users = user_register(user_id, user_dict)
        if not users: raise Exception('Register user failure.')
        if len(users) > 1: raise Exception('Same E-mail exist!')
        if not users:
            raise Exception('Register user failure.')
        if len(users) > 1:
            raise Exception('Same E-mail exist!')
        user = users[0]
        login_user(user)
        return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome aboard!")
        return cors_reponse(data=user.to_json(),
                            auth=user.get_id(), retmsg="Welcome aboard!")
    except Exception as e:
        rollback_user_registration(user_id)
        stat_logger.exception(e)
        return get_json_result(data=False, retmsg='User registration failure!', retcode=RetCode.EXCEPTION_ERROR)
        return get_json_result(
            data=False, retmsg='User registration failure!', retcode=RetCode.EXCEPTION_ERROR)
@manager.route("/tenant_info", methods=["GET"])
| @manager.route("/tenant_info", methods=["GET"]) | |||
| @@ -50,7 +50,13 @@ def singleton(cls, *args, **kw): | |||
| CONTINUOUS_FIELD_TYPE = {IntegerField, FloatField, DateTimeField} | |||
| AUTO_DATE_TIMESTAMP_FIELD_PREFIX = {"create", "start", "end", "update", "read_access", "write_access"} | |||
| AUTO_DATE_TIMESTAMP_FIELD_PREFIX = { | |||
| "create", | |||
| "start", | |||
| "end", | |||
| "update", | |||
| "read_access", | |||
| "write_access"} | |||
class LongTextField(TextField):
@@ -73,7 +79,8 @@ class JSONField(LongTextField):
    def python_value(self, value):
        if not value:
            return self.default_value
        return utils.json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
        return utils.json_loads(
            value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
class ListField(JSONField):
@@ -81,7 +88,8 @@ class ListField(JSONField):
class SerializedField(LongTextField):
    def __init__(self, serialized_type=SerializedType.PICKLE, object_hook=None, object_pairs_hook=None, **kwargs):
    def __init__(self, serialized_type=SerializedType.PICKLE,
                 object_hook=None, object_pairs_hook=None, **kwargs):
        self._serialized_type = serialized_type
        self._object_hook = object_hook
        self._object_pairs_hook = object_pairs_hook
@@ -95,7 +103,8 @@ class SerializedField(LongTextField):
                return None
            return utils.json_dumps(value, with_type=True)
        else:
            raise ValueError(f"the serialized type {self._serialized_type} is not supported")
            raise ValueError(
                f"the serialized type {self._serialized_type} is not supported")
    def python_value(self, value):
        if self._serialized_type == SerializedType.PICKLE:
@@ -103,9 +112,11 @@ class SerializedField(LongTextField):
        elif self._serialized_type == SerializedType.JSON:
            if value is None:
                return {}
            return utils.json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
            return utils.json_loads(
                value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
        else:
            raise ValueError(f"the serialized type {self._serialized_type} is not supported")
            raise ValueError(
                f"the serialized type {self._serialized_type} is not supported")
def is_continuous_field(cls: typing.Type) -> bool:
@@ -150,7 +161,8 @@ class BaseModel(Model):
        model_dict = self.__dict__['__data__']
        if not only_primary_with:
            return {remove_field_name_prefix(k): v for k, v in model_dict.items()}
            return {remove_field_name_prefix(
                k): v for k, v in model_dict.items()}
        human_model_dict = {}
        for k in self._meta.primary_key.field_names:
@@ -184,17 +196,22 @@ class BaseModel(Model):
                if is_continuous_field(type(getattr(cls, attr_name))):
                    if len(f_v) == 2:
                        for i, v in enumerate(f_v):
                            if isinstance(v, str) and f_n in auto_date_timestamp_field():
                            if isinstance(
                                    v, str) and f_n in auto_date_timestamp_field():
                                # time type: %Y-%m-%d %H:%M:%S
                                f_v[i] = utils.date_string_to_timestamp(v)
                        lt_value = f_v[0]
                        gt_value = f_v[1]
                        if lt_value is not None and gt_value is not None:
                            filters.append(cls.getter_by(attr_name).between(lt_value, gt_value))
                            filters.append(
                                cls.getter_by(attr_name).between(
                                    lt_value, gt_value))
                        elif lt_value is not None:
                            filters.append(operator.attrgetter(attr_name)(cls) >= lt_value)
                            filters.append(
                                operator.attrgetter(attr_name)(cls) >= lt_value)
                        elif gt_value is not None:
                            filters.append(operator.attrgetter(attr_name)(cls) <= gt_value)
                            filters.append(
                                operator.attrgetter(attr_name)(cls) <= gt_value)
                    else:
                        filters.append(operator.attrgetter(attr_name)(cls) << f_v)
                else:
@@ -205,9 +222,11 @@ class BaseModel(Model):
            if not order_by or not hasattr(cls, f"{order_by}"):
                order_by = "create_time"
            if reverse is True:
                query_records = query_records.order_by(cls.getter_by(f"{order_by}").desc())
                query_records = query_records.order_by(
                    cls.getter_by(f"{order_by}").desc())
            elif reverse is False:
                query_records = query_records.order_by(cls.getter_by(f"{order_by}").asc())
                query_records = query_records.order_by(
                    cls.getter_by(f"{order_by}").asc())
            return [query_record for query_record in query_records]
        else:
            return []
@@ -215,7 +234,8 @@ class BaseModel(Model):
    @classmethod
    def insert(cls, __data=None, **insert):
        if isinstance(__data, dict) and __data:
            __data[cls._meta.combined["create_time"]] = utils.current_timestamp()
            __data[cls._meta.combined["create_time"]
                   ] = utils.current_timestamp()
        if insert:
            insert["create_time"] = utils.current_timestamp()
@@ -228,7 +248,8 @@ class BaseModel(Model):
        if not normalized:
            return {}
        normalized[cls._meta.combined["update_time"]] = utils.current_timestamp()
        normalized[cls._meta.combined["update_time"]
                   ] = utils.current_timestamp()
        for f_n in AUTO_DATE_TIMESTAMP_FIELD_PREFIX:
            if {f"{f_n}_time", f"{f_n}_date"}.issubset(cls._meta.combined.keys()) and \
@@ -241,7 +262,8 @@ class BaseModel(Model):
class JsonSerializedField(SerializedField):
    def __init__(self, object_hook=utils.from_dict_hook, object_pairs_hook=None, **kwargs):
    def __init__(self, object_hook=utils.from_dict_hook,
                 object_pairs_hook=None, **kwargs):
        super(JsonSerializedField, self).__init__(serialized_type=SerializedType.JSON, object_hook=object_hook,
                                                  object_pairs_hook=object_pairs_hook, **kwargs)
@@ -251,7 +273,8 @@ class BaseDataBase:
    def __init__(self):
        database_config = DATABASE.copy()
        db_name = database_config.pop("name")
        self.database_connection = PooledMySQLDatabase(db_name, **database_config)
        self.database_connection = PooledMySQLDatabase(
            db_name, **database_config)
        stat_logger.info('init mysql database on cluster mode successfully')
@@ -263,7 +286,8 @@ class DatabaseLock:
    def lock(self):
        # SQL parameters only support %s format placeholders
        cursor = self.db.execute_sql("SELECT GET_LOCK(%s, %s)", (self.lock_name, self.timeout))
        cursor = self.db.execute_sql(
            "SELECT GET_LOCK(%s, %s)", (self.lock_name, self.timeout))
        ret = cursor.fetchone()
        if ret[0] == 0:
            raise Exception(f'acquire mysql lock {self.lock_name} timeout')
@@ -273,10 +297,12 @@ class DatabaseLock:
            raise Exception(f'failed to acquire lock {self.lock_name}')
    def unlock(self):
        cursor = self.db.execute_sql("SELECT RELEASE_LOCK(%s)", (self.lock_name,))
        cursor = self.db.execute_sql(
            "SELECT RELEASE_LOCK(%s)", (self.lock_name,))
        ret = cursor.fetchone()
        if ret[0] == 0:
            raise Exception(f'mysql lock {self.lock_name} was not established by this thread')
            raise Exception(
                f'mysql lock {self.lock_name} was not established by this thread')
        elif ret[0] == 1:
            return True
        else:
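Usage note for DatabaseLock: GET_LOCK/RELEASE_LOCK give a cluster-wide mutex held by one MySQL session at a time, so lock() must always be paired with unlock(). A sketch of the safe calling pattern (constructor arguments here are assumptions based on the attributes used above):

lock = DatabaseLock(lock_name="init_database_tables", timeout=10, db=DB)
lock.lock()
try:
    # Critical section: only one process across the cluster runs this at once.
    pass
finally:
    lock.unlock()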
@@ -350,17 +376,37 @@ class User(DataBaseModel, UserMixin):
    access_token = CharField(max_length=255, null=True)
    nickname = CharField(max_length=100, null=False, help_text="nicky name")
    password = CharField(max_length=255, null=True, help_text="password")
    email = CharField(max_length=255, null=False, help_text="email", index=True)
    email = CharField(
        max_length=255,
        null=False,
        help_text="email",
        index=True)
    avatar = TextField(null=True, help_text="avatar base64 string")
    language = CharField(max_length=32, null=True, help_text="English|Chinese", default="Chinese")
    color_schema = CharField(max_length=32, null=True, help_text="Bright|Dark", default="Bright")
    timezone = CharField(max_length=64, null=True, help_text="Timezone", default="UTC+8\tAsia/Shanghai")
    language = CharField(
        max_length=32,
        null=True,
        help_text="English|Chinese",
        default="Chinese")
    color_schema = CharField(
        max_length=32,
        null=True,
        help_text="Bright|Dark",
        default="Bright")
    timezone = CharField(
        max_length=64,
        null=True,
        help_text="Timezone",
        default="UTC+8\tAsia/Shanghai")
    last_login_time = DateTimeField(null=True)
    is_authenticated = CharField(max_length=1, null=False, default="1")
    is_active = CharField(max_length=1, null=False, default="1")
    is_anonymous = CharField(max_length=1, null=False, default="0")
    login_channel = CharField(null=True, help_text="from which user login")
    status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
    status = CharField(
        max_length=1,
        null=True,
        help_text="is it validate(0: wasted,1: validate)",
        default="1")
    is_superuser = BooleanField(null=True, help_text="is root", default=False)
    def __str__(self):
@@ -379,12 +425,28 @@ class Tenant(DataBaseModel):
    name = CharField(max_length=100, null=True, help_text="Tenant name")
    public_key = CharField(max_length=255, null=True)
    llm_id = CharField(max_length=128, null=False, help_text="default llm ID")
    embd_id = CharField(max_length=128, null=False, help_text="default embedding model ID")
    asr_id = CharField(max_length=128, null=False, help_text="default ASR model ID")
    img2txt_id = CharField(max_length=128, null=False, help_text="default image to text model ID")
    parser_ids = CharField(max_length=256, null=False, help_text="document processors")
    embd_id = CharField(
        max_length=128,
        null=False,
        help_text="default embedding model ID")
    asr_id = CharField(
        max_length=128,
        null=False,
        help_text="default ASR model ID")
    img2txt_id = CharField(
        max_length=128,
        null=False,
        help_text="default image to text model ID")
    parser_ids = CharField(
        max_length=256,
        null=False,
        help_text="document processors")
    credit = IntegerField(default=512)
    status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
    status = CharField(
        max_length=1,
        null=True,
        help_text="is it validate(0: wasted,1: validate)",
        default="1")
    class Meta:
        db_table = "tenant"
@@ -396,7 +458,11 @@ class UserTenant(DataBaseModel):
    tenant_id = CharField(max_length=32, null=False)
    role = CharField(max_length=32, null=False, help_text="UserTenantRole")
    invited_by = CharField(max_length=32, null=False)
    status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
    status = CharField(
        max_length=1,
        null=True,
        help_text="is it validate(0: wasted,1: validate)",
        default="1")
    class Meta:
        db_table = "user_tenant"
@@ -408,17 +474,32 @@ class InvitationCode(DataBaseModel):
    visit_time = DateTimeField(null=True)
    user_id = CharField(max_length=32, null=True)
    tenant_id = CharField(max_length=32, null=True)
    status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
    status = CharField(
        max_length=1,
        null=True,
        help_text="is it validate(0: wasted,1: validate)",
        default="1")
    class Meta:
        db_table = "invitation_code"
class LLMFactories(DataBaseModel):
    name = CharField(max_length=128, null=False, help_text="LLM factory name", primary_key=True)
    name = CharField(
        max_length=128,
        null=False,
        help_text="LLM factory name",
        primary_key=True)
    logo = TextField(null=True, help_text="llm logo base64")
    tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, ASR")
    status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
    tags = CharField(
        max_length=255,
        null=False,
        help_text="LLM, Text Embedding, Image2Text, ASR")
    status = CharField(
        max_length=1,
        null=True,
        help_text="is it validate(0: wasted,1: validate)",
        default="1")
    def __str__(self):
        return self.name
@@ -429,12 +510,27 @@ class LLMFactories(DataBaseModel):
class LLM(DataBaseModel):
    # LLMs dictionary
    llm_name = CharField(max_length=128, null=False, help_text="LLM name", index=True, primary_key=True)
    model_type = CharField(max_length=128, null=False, help_text="LLM, Text Embedding, Image2Text, ASR")
    llm_name = CharField(
        max_length=128,
        null=False,
        help_text="LLM name",
        index=True,
        primary_key=True)
    model_type = CharField(
        max_length=128,
        null=False,
        help_text="LLM, Text Embedding, Image2Text, ASR")
    fid = CharField(max_length=128, null=False, help_text="LLM factory id")
    max_tokens = IntegerField(default=0)
    tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, Chat, 32k...")
    status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
    tags = CharField(
        max_length=255,
        null=False,
        help_text="LLM, Text Embedding, Image2Text, Chat, 32k...")
    status = CharField(
        max_length=1,
        null=True,
        help_text="is it validate(0: wasted,1: validate)",
        default="1")
    def __str__(self):
        return self.llm_name
@@ -445,9 +541,19 @@ class LLM(DataBaseModel):
class TenantLLM(DataBaseModel):
    tenant_id = CharField(max_length=32, null=False)
    llm_factory = CharField(max_length=128, null=False, help_text="LLM factory name")
    model_type = CharField(max_length=128, null=True, help_text="LLM, Text Embedding, Image2Text, ASR")
    llm_name = CharField(max_length=128, null=True, help_text="LLM name", default="")
    llm_factory = CharField(
        max_length=128,
        null=False,
        help_text="LLM factory name")
    model_type = CharField(
        max_length=128,
        null=True,
        help_text="LLM, Text Embedding, Image2Text, ASR")
    llm_name = CharField(
        max_length=128,
        null=True,
        help_text="LLM name",
        default="")
    api_key = CharField(max_length=255, null=True, help_text="API KEY")
    api_base = CharField(max_length=255, null=True, help_text="API Base")
    used_tokens = IntegerField(default=0)
@@ -464,11 +570,26 @@ class Knowledgebase(DataBaseModel):
    id = CharField(max_length=32, primary_key=True)
    avatar = TextField(null=True, help_text="avatar base64 string")
    tenant_id = CharField(max_length=32, null=False)
    name = CharField(max_length=128, null=False, help_text="KB name", index=True)
    language = CharField(max_length=32, null=True, default="Chinese", help_text="English|Chinese")
    name = CharField(
        max_length=128,
        null=False,
        help_text="KB name",
        index=True)
    language = CharField(
        max_length=32,
        null=True,
        default="Chinese",
        help_text="English|Chinese")
    description = TextField(null=True, help_text="KB description")
    embd_id = CharField(max_length=128, null=False, help_text="default embedding model ID")
    permission = CharField(max_length=16, null=False, help_text="me|team", default="me")
    embd_id = CharField(
        max_length=128,
        null=False,
        help_text="default embedding model ID")
    permission = CharField(
        max_length=16,
        null=False,
        help_text="me|team",
        default="me")
    created_by = CharField(max_length=32, null=False)
    doc_num = IntegerField(default=0)
    token_num = IntegerField(default=0)
@@ -476,9 +597,17 @@ class Knowledgebase(DataBaseModel):
    similarity_threshold = FloatField(default=0.2)
    vector_similarity_weight = FloatField(default=0.3)
    parser_id = CharField(max_length=32, null=False, help_text="default parser ID", default=ParserType.NAIVE.value)
    parser_config = JSONField(null=False, default={"pages":[[1,1000000]]})
    status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
    parser_id = CharField(
        max_length=32,
        null=False,
        help_text="default parser ID",
        default=ParserType.NAIVE.value)
    parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]})
    status = CharField(
        max_length=1,
        null=True,
        help_text="is it validate(0: wasted,1: validate)",
        default="1")
    def __str__(self):
        return self.name
| @@ -491,22 +620,50 @@ class Document(DataBaseModel): | |||
| id = CharField(max_length=32, primary_key=True) | |||
| thumbnail = TextField(null=True, help_text="thumbnail base64 string") | |||
| kb_id = CharField(max_length=256, null=False, index=True) | |||
| parser_id = CharField(max_length=32, null=False, help_text="default parser ID") | |||
| parser_config = JSONField(null=False, default={"pages":[[1,1000000]]}) | |||
| source_type = CharField(max_length=128, null=False, default="local", help_text="where dose this document from") | |||
| parser_id = CharField( | |||
| max_length=32, | |||
| null=False, | |||
| help_text="default parser ID") | |||
| parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]}) | |||
| source_type = CharField( | |||
| max_length=128, | |||
| null=False, | |||
| default="local", | |||
| help_text="where dose this document from") | |||
| type = CharField(max_length=32, null=False, help_text="file extension") | |||
| created_by = CharField(max_length=32, null=False, help_text="who created it") | |||
| name = CharField(max_length=255, null=True, help_text="file name", index=True) | |||
| location = CharField(max_length=255, null=True, help_text="where dose it store") | |||
| created_by = CharField( | |||
| max_length=32, | |||
| null=False, | |||
| help_text="who created it") | |||
| name = CharField( | |||
| max_length=255, | |||
| null=True, | |||
| help_text="file name", | |||
| index=True) | |||
| location = CharField( | |||
| max_length=255, | |||
| null=True, | |||
| help_text="where dose it store") | |||
| size = IntegerField(default=0) | |||
| token_num = IntegerField(default=0) | |||
| chunk_num = IntegerField(default=0) | |||
| progress = FloatField(default=0) | |||
| progress_msg = TextField(null=True, help_text="process message", default="") | |||
| progress_msg = TextField( | |||
| null=True, | |||
| help_text="process message", | |||
| default="") | |||
| process_begin_at = DateTimeField(null=True) | |||
| process_duation = FloatField(default=0) | |||
| run = CharField(max_length=1, null=True, help_text="start to run processing or cancel.(1: run it; 2: cancel)", default="0") | |||
| status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1") | |||
| run = CharField( | |||
| max_length=1, | |||
| null=True, | |||
| help_text="start to run processing or cancel.(1: run it; 2: cancel)", | |||
| default="0") | |||
| status = CharField( | |||
| max_length=1, | |||
| null=True, | |||
| help_text="is it validate(0: wasted,1: validate)", | |||
| default="1") | |||
| class Meta: | |||
| db_table = "document" | |||
| @@ -520,30 +677,52 @@ class Task(DataBaseModel): | |||
| begin_at = DateTimeField(null=True) | |||
| process_duation = FloatField(default=0) | |||
| progress = FloatField(default=0) | |||
| progress_msg = TextField(null=True, help_text="process message", default="") | |||
| progress_msg = TextField( | |||
| null=True, | |||
| help_text="process message", | |||
| default="") | |||
| class Dialog(DataBaseModel): | |||
| id = CharField(max_length=32, primary_key=True) | |||
| tenant_id = CharField(max_length=32, null=False) | |||
| name = CharField(max_length=255, null=True, help_text="dialog application name") | |||
| name = CharField( | |||
| max_length=255, | |||
| null=True, | |||
| help_text="dialog application name") | |||
| description = TextField(null=True, help_text="Dialog description") | |||
| icon = TextField(null=True, help_text="icon base64 string") | |||
| language = CharField(max_length=32, null=True, default="Chinese", help_text="English|Chinese") | |||
| language = CharField( | |||
| max_length=32, | |||
| null=True, | |||
| default="Chinese", | |||
| help_text="English|Chinese") | |||
| llm_id = CharField(max_length=32, null=False, help_text="default llm ID") | |||
| llm_setting = JSONField(null=False, default={"temperature": 0.1, "top_p": 0.3, "frequency_penalty": 0.7, | |||
| "presence_penalty": 0.4, "max_tokens": 215}) | |||
| prompt_type = CharField(max_length=16, null=False, default="simple", help_text="simple|advanced") | |||
| prompt_type = CharField( | |||
| max_length=16, | |||
| null=False, | |||
| default="simple", | |||
| help_text="simple|advanced") | |||
| prompt_config = JSONField(null=False, default={"system": "", "prologue": "您好,我是您的助手小樱,长得可爱又善良,can I help you?", | |||
| "parameters": [], "empty_response": "Sorry! 知识库中未找到相关内容!"}) | |||
| similarity_threshold = FloatField(default=0.2) | |||
| vector_similarity_weight = FloatField(default=0.3) | |||
| top_n = IntegerField(default=6) | |||
| do_refer = CharField(max_length=1, null=False, help_text="it needs to insert reference index into answer or not", default="1") | |||
| do_refer = CharField( | |||
| max_length=1, | |||
| null=False, | |||
| help_text="it needs to insert reference index into answer or not", | |||
| default="1") | |||
| kb_ids = JSONField(null=False, default=[]) | |||
| status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1") | |||
| status = CharField( | |||
| max_length=1, | |||
| null=True, | |||
| help_text="is it validate(0: wasted,1: validate)", | |||
| default="1") | |||
| class Meta: | |||
| db_table = "dialog" | |||
| @@ -32,8 +32,7 @@ LOGGER = getLogger() | |||
| def bulk_insert_into_db(model, data_source, replace_on_conflict=False): | |||
| DB.create_tables([model]) | |||
| for i,data in enumerate(data_source): | |||
| for i, data in enumerate(data_source): | |||
| current_time = current_timestamp() + i | |||
| current_date = timestamp_to_date(current_time) | |||
| if 'create_time' not in data: | |||
| @@ -55,7 +54,8 @@ def bulk_insert_into_db(model, data_source, replace_on_conflict=False): | |||
| def get_dynamic_db_model(base, job_id): | |||
| return type(base.model(table_index=get_dynamic_tracking_table_index(job_id=job_id))) | |||
| return type(base.model( | |||
| table_index=get_dynamic_tracking_table_index(job_id=job_id))) | |||
| def get_dynamic_tracking_table_index(job_id): | |||
| @@ -86,7 +86,9 @@ supported_operators = { | |||
| '~': operator.inv, | |||
| } | |||
| def query_dict2expression(model: Type[DataBaseModel], query: Dict[str, Union[bool, int, str, list, tuple]]): | |||
| def query_dict2expression( | |||
| model: Type[DataBaseModel], query: Dict[str, Union[bool, int, str, list, tuple]]): | |||
| expression = [] | |||
| for field, value in query.items(): | |||
| @@ -95,7 +97,10 @@ def query_dict2expression(model: Type[DataBaseModel], query: Dict[str, Union[boo | |||
| op, *val = value | |||
| field = getattr(model, f'f_{field}') | |||
| value = supported_operators[op](field, val[0]) if op in supported_operators else getattr(field, op)(*val) | |||
| value = (supported_operators[op](field, val[0]) | |||
| if op in supported_operators else getattr(field, op)(*val)) | |||
| expression.append(value) | |||
| return reduce(operator.iand, expression) | |||
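For orientation, a hedged usage sketch of `query_dict2expression`; the `Job` model and values are invented, and only the tuple-valued branch visible in the hunk above is exercised:

```python
# Hypothetical usage -- the Job model and values are invented.
# ('>', 0)            -> supported_operators['>'](f_create_time, 0)
# ('contains', 'doc') -> getattr(f_name, 'contains')('doc')
expr = query_dict2expression(Job, {
    "create_time": (">", 0),
    "name": ("contains", "doc"),
})
rows = Job.select().where(expr)  # conditions are AND-ed via operator.iand
```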
| @@ -61,45 +61,54 @@ def init_superuser(): | |||
| TenantService.insert(**tenant) | |||
| UserTenantService.insert(**usr_tenant) | |||
| TenantLLMService.insert_many(tenant_llm) | |||
| print("【INFO】Super user initialized. \033[93memail: admin@ragflow.io, password: admin\033[0m. Changing the password after logining is strongly recomanded.") | |||
| print( | |||
| "【INFO】Super user initialized. \033[93memail: admin@ragflow.io, password: admin\033[0m. Changing the password after logining is strongly recomanded.") | |||
| chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"]) | |||
| msg = chat_mdl.chat(system="", history=[{"role": "user", "content": "Hello!"}], gen_conf={}) | |||
| msg = chat_mdl.chat(system="", history=[ | |||
| {"role": "user", "content": "Hello!"}], gen_conf={}) | |||
| if msg.find("ERROR: ") == 0: | |||
| print("\33[91m【ERROR】\33[0m: ", "'{}' dosen't work. {}".format(tenant["llm_id"], msg)) | |||
| print( | |||
| "\33[91m【ERROR】\33[0m: ", | |||
| "'{}' dosen't work. {}".format( | |||
| tenant["llm_id"], | |||
| msg)) | |||
| embd_mdl = LLMBundle(tenant["id"], LLMType.EMBEDDING, tenant["embd_id"]) | |||
| v, c = embd_mdl.encode(["Hello!"]) | |||
| if c == 0: | |||
| print("\33[91m【ERROR】\33[0m:", " '{}' dosen't work!".format(tenant["embd_id"])) | |||
| print( | |||
| "\33[91m【ERROR】\33[0m:", | |||
| " '{}' dosen't work!".format( | |||
| tenant["embd_id"])) | |||
| factory_infos = [{ | |||
| "name": "OpenAI", | |||
| "logo": "", | |||
| "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION", | |||
| "name": "OpenAI", | |||
| "logo": "", | |||
| "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION", | |||
| "status": "1", | |||
| }, { | |||
| "name": "Tongyi-Qianwen", | |||
| "logo": "", | |||
| "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION", | |||
| "status": "1", | |||
| }, { | |||
| "name": "ZHIPU-AI", | |||
| "logo": "", | |||
| "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION", | |||
| "status": "1", | |||
| }, | |||
| { | |||
| "name": "Local", | |||
| "logo": "", | |||
| "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION", | |||
| "status": "1", | |||
| },{ | |||
| "name": "Tongyi-Qianwen", | |||
| "logo": "", | |||
| "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION", | |||
| "status": "1", | |||
| },{ | |||
| "name": "ZHIPU-AI", | |||
| "logo": "", | |||
| "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION", | |||
| "status": "1", | |||
| }, | |||
| { | |||
| "name": "Local", | |||
| "logo": "", | |||
| "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION", | |||
| "status": "1", | |||
| },{ | |||
| }, { | |||
| "name": "Moonshot", | |||
| "logo": "", | |||
| "tags": "LLM,TEXT EMBEDDING", | |||
| "status": "1", | |||
| } | |||
| "logo": "", | |||
| "tags": "LLM,TEXT EMBEDDING", | |||
| "status": "1", | |||
| } | |||
| # { | |||
| # "name": "文心一言", | |||
| # "logo": "", | |||
| @@ -107,6 +116,8 @@ factory_infos = [{ | |||
| # "status": "1", | |||
| # }, | |||
| ] | |||
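Reviewer's note: the `llm_infos` entries below key off list positions (`factory_infos[0]["name"]`, `factory_infos[1]["name"]`, `factory_infos[4]["name"]`), so re-enabling the commented-out 文心一言 entry mid-list would silently shift every later index. A hedged sketch of a name-keyed lookup that avoids this; the helper name is invented:

```python
# Hypothetical refactor sketch -- not part of this diff.
FACTORY_BY_NAME = {f["name"]: f for f in factory_infos}

def fid(name: str) -> str:
    """Look a factory up by name rather than by list position."""
    return FACTORY_BY_NAME[name]["name"]

# e.g. "fid": fid("Moonshot") instead of factory_infos[4]["name"]
```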
| def init_llm_factory(): | |||
| llm_infos = [ | |||
| # ---------------------- OpenAI ------------------------ | |||
| @@ -116,37 +127,37 @@ def init_llm_factory(): | |||
| "tags": "LLM,CHAT,4K", | |||
| "max_tokens": 4096, | |||
| "model_type": LLMType.CHAT.value | |||
| },{ | |||
| }, { | |||
| "fid": factory_infos[0]["name"], | |||
| "llm_name": "gpt-3.5-turbo-16k-0613", | |||
| "tags": "LLM,CHAT,16k", | |||
| "max_tokens": 16385, | |||
| "model_type": LLMType.CHAT.value | |||
| },{ | |||
| }, { | |||
| "fid": factory_infos[0]["name"], | |||
| "llm_name": "text-embedding-ada-002", | |||
| "tags": "TEXT EMBEDDING,8K", | |||
| "max_tokens": 8191, | |||
| "model_type": LLMType.EMBEDDING.value | |||
| },{ | |||
| }, { | |||
| "fid": factory_infos[0]["name"], | |||
| "llm_name": "whisper-1", | |||
| "tags": "SPEECH2TEXT", | |||
| "max_tokens": 25*1024*1024, | |||
| "max_tokens": 25 * 1024 * 1024, | |||
| "model_type": LLMType.SPEECH2TEXT.value | |||
| },{ | |||
| }, { | |||
| "fid": factory_infos[0]["name"], | |||
| "llm_name": "gpt-4", | |||
| "tags": "LLM,CHAT,8K", | |||
| "max_tokens": 8191, | |||
| "model_type": LLMType.CHAT.value | |||
| },{ | |||
| }, { | |||
| "fid": factory_infos[0]["name"], | |||
| "llm_name": "gpt-4-32k", | |||
| "tags": "LLM,CHAT,32K", | |||
| "max_tokens": 32768, | |||
| "model_type": LLMType.CHAT.value | |||
| },{ | |||
| }, { | |||
| "fid": factory_infos[0]["name"], | |||
| "llm_name": "gpt-4-vision-preview", | |||
| "tags": "LLM,CHAT,IMAGE2TEXT", | |||
| @@ -160,31 +171,31 @@ def init_llm_factory(): | |||
| "tags": "LLM,CHAT,8K", | |||
| "max_tokens": 8191, | |||
| "model_type": LLMType.CHAT.value | |||
| },{ | |||
| }, { | |||
| "fid": factory_infos[1]["name"], | |||
| "llm_name": "qwen-plus", | |||
| "tags": "LLM,CHAT,32K", | |||
| "max_tokens": 32768, | |||
| "model_type": LLMType.CHAT.value | |||
| },{ | |||
| }, { | |||
| "fid": factory_infos[1]["name"], | |||
| "llm_name": "qwen-max-1201", | |||
| "tags": "LLM,CHAT,6K", | |||
| "max_tokens": 5899, | |||
| "model_type": LLMType.CHAT.value | |||
| },{ | |||
| }, { | |||
| "fid": factory_infos[1]["name"], | |||
| "llm_name": "text-embedding-v2", | |||
| "tags": "TEXT EMBEDDING,2K", | |||
| "max_tokens": 2048, | |||
| "model_type": LLMType.EMBEDDING.value | |||
| },{ | |||
| }, { | |||
| "fid": factory_infos[1]["name"], | |||
| "llm_name": "paraformer-realtime-8k-v1", | |||
| "tags": "SPEECH2TEXT", | |||
| "max_tokens": 25*1024*1024, | |||
| "max_tokens": 25 * 1024 * 1024, | |||
| "model_type": LLMType.SPEECH2TEXT.value | |||
| },{ | |||
| }, { | |||
| "fid": factory_infos[1]["name"], | |||
| "llm_name": "qwen-vl-max", | |||
| "tags": "LLM,CHAT,IMAGE2TEXT", | |||
| @@ -245,13 +256,13 @@ def init_llm_factory(): | |||
| "tags": "TEXT EMBEDDING,", | |||
| "max_tokens": 128 * 1000, | |||
| "model_type": LLMType.EMBEDDING.value | |||
| },{ | |||
| }, { | |||
| "fid": factory_infos[4]["name"], | |||
| "llm_name": "moonshot-v1-32k", | |||
| "tags": "LLM,CHAT,", | |||
| "max_tokens": 32768, | |||
| "model_type": LLMType.CHAT.value | |||
| },{ | |||
| }, { | |||
| "fid": factory_infos[4]["name"], | |||
| "llm_name": "moonshot-v1-128k", | |||
| "tags": "LLM,CHAT", | |||
| @@ -294,7 +305,6 @@ def init_web_data(): | |||
| print("init web data success:{}".format(time.time() - start_time)) | |||
| if __name__ == '__main__': | |||
| init_web_db() | |||
| init_web_data() | |||
| init_web_data() | |||
| @@ -18,4 +18,4 @@ import operator | |||
| import time | |||
| import typing | |||
| from api.utils.log_utils import sql_logger | |||
| import peewee | |||
| import peewee | |||
| @@ -18,10 +18,11 @@ class ReloadConfigBase: | |||
| def get_all(cls): | |||
| configs = {} | |||
| for k, v in cls.__dict__.items(): | |||
| if not callable(getattr(cls, k)) and not k.startswith("__") and not k.startswith("_"): | |||
| if not callable(getattr(cls, k)) and not k.startswith( | |||
| "__") and not k.startswith("_"): | |||
| configs[k] = v | |||
| return configs | |||
| @classmethod | |||
| def get(cls, config_name): | |||
| return getattr(cls, config_name) if hasattr(cls, config_name) else None | |||
| return getattr(cls, config_name) if hasattr(cls, config_name) else None | |||
| @@ -51,4 +51,4 @@ class RuntimeConfig(ReloadConfigBase): | |||
| @classmethod | |||
| def set_service_db(cls, service_db): | |||
| cls.SERVICE_DB = service_db | |||
| cls.SERVICE_DB = service_db | |||
| @@ -27,7 +27,8 @@ class CommonService: | |||
| @classmethod | |||
| @DB.connection_context() | |||
| def query(cls, cols=None, reverse=None, order_by=None, **kwargs): | |||
| return cls.model.query(cols=cols, reverse=reverse, order_by=order_by, **kwargs) | |||
| return cls.model.query(cols=cols, reverse=reverse, | |||
| order_by=order_by, **kwargs) | |||
| @classmethod | |||
| @DB.connection_context() | |||
| @@ -40,9 +41,11 @@ class CommonService: | |||
| if not order_by or not hasattr(cls, order_by): | |||
| order_by = "create_time" | |||
| if reverse is True: | |||
| query_records = query_records.order_by(cls.model.getter_by(order_by).desc()) | |||
| query_records = query_records.order_by( | |||
| cls.model.getter_by(order_by).desc()) | |||
| elif reverse is False: | |||
| query_records = query_records.order_by(cls.model.getter_by(order_by).asc()) | |||
| query_records = query_records.order_by( | |||
| cls.model.getter_by(order_by).asc()) | |||
| return query_records | |||
| @classmethod | |||
| @@ -61,7 +64,7 @@ class CommonService: | |||
| @classmethod | |||
| @DB.connection_context() | |||
| def save(cls, **kwargs): | |||
| #if "id" not in kwargs: | |||
| # if "id" not in kwargs: | |||
| # kwargs["id"] = get_uuid() | |||
| sample_obj = cls.model(**kwargs).save(force_insert=True) | |||
| return sample_obj | |||
| @@ -95,7 +98,8 @@ class CommonService: | |||
| for data in data_list: | |||
| data["update_time"] = current_timestamp() | |||
| data["update_date"] = datetime_format(datetime.now()) | |||
| cls.model.update(data).where(cls.model.id == data["id"]).execute() | |||
| cls.model.update(data).where( | |||
| cls.model.id == data["id"]).execute() | |||
| @classmethod | |||
| @DB.connection_context() | |||
| @@ -128,7 +132,6 @@ class CommonService: | |||
| def delete_by_id(cls, pid): | |||
| return cls.model.delete().where(cls.model.id == pid).execute() | |||
| @classmethod | |||
| @DB.connection_context() | |||
| def filter_delete(cls, filters): | |||
| @@ -151,19 +154,30 @@ class CommonService: | |||
| @classmethod | |||
| @DB.connection_context() | |||
| def filter_scope_list(cls, in_key, in_filters_list, filters=None, cols=None): | |||
| def filter_scope_list(cls, in_key, in_filters_list, | |||
| filters=None, cols=None): | |||
| in_filters_tuple_list = cls.cut_list(in_filters_list, 20) | |||
| if not filters: | |||
| filters = [] | |||
| res_list = [] | |||
| if cols: | |||
| for i in in_filters_tuple_list: | |||
| query_records = cls.model.select(*cols).where(getattr(cls.model, in_key).in_(i), *filters) | |||
| query_records = cls.model.select(*cols).where( | |||
| getattr(cls.model, in_key).in_(i), *filters) | |||
| if query_records: | |||
| res_list.extend([query_record for query_record in query_records]) | |||
| res_list.extend( | |||
| [query_record for query_record in query_records]) | |||
| else: | |||
| for i in in_filters_tuple_list: | |||
| query_records = cls.model.select().where(getattr(cls.model, in_key).in_(i), *filters) | |||
| query_records = cls.model.select().where( | |||
| getattr(cls.model, in_key).in_(i), *filters) | |||
| if query_records: | |||
| res_list.extend([query_record for query_record in query_records]) | |||
| return res_list | |||
| res_list.extend( | |||
| [query_record for query_record in query_records]) | |||
| return res_list | |||
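For context: `filter_scope_list` splits the `IN` values into tuples of 20 via `cls.cut_list` before querying, which keeps each `IN (...)` clause and its bound-parameter count small. `cut_list` itself is not shown in this diff; a plausible sketch of such a batching helper:

```python
# Hypothetical sketch of the batching helper referenced above.
def cut_list(data, block_size=20):
    """Return successive fixed-size slices of `data` as tuples."""
    return [tuple(data[i:i + block_size])
            for i in range(0, len(data), block_size)]

# cut_list(["a", "b", "c"], 2) -> [("a", "b"), ("c",)]
```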
| @@ -21,6 +21,5 @@ class DialogService(CommonService): | |||
| model = Dialog | |||
| class ConversationService(CommonService): | |||
| model = Conversation | |||
| @@ -72,7 +72,20 @@ class DocumentService(CommonService): | |||
| @classmethod | |||
| @DB.connection_context() | |||
| def get_newly_uploaded(cls, tm, mod=0, comm=1, items_per_page=64): | |||
| fields = [cls.model.id, cls.model.kb_id, cls.model.parser_id, cls.model.parser_config, cls.model.name, cls.model.type, cls.model.location, cls.model.size, Knowledgebase.tenant_id, Tenant.embd_id, Tenant.img2txt_id, Tenant.asr_id, cls.model.update_time] | |||
| fields = [ | |||
| cls.model.id, | |||
| cls.model.kb_id, | |||
| cls.model.parser_id, | |||
| cls.model.parser_config, | |||
| cls.model.name, | |||
| cls.model.type, | |||
| cls.model.location, | |||
| cls.model.size, | |||
| Knowledgebase.tenant_id, | |||
| Tenant.embd_id, | |||
| Tenant.img2txt_id, | |||
| Tenant.asr_id, | |||
| cls.model.update_time] | |||
| docs = cls.model.select(*fields) \ | |||
| .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) \ | |||
| .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))\ | |||
| @@ -103,40 +116,64 @@ class DocumentService(CommonService): | |||
| @DB.connection_context() | |||
| def increment_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation): | |||
| num = cls.model.update(token_num=cls.model.token_num + token_num, | |||
| chunk_num=cls.model.chunk_num + chunk_num, | |||
| process_duation=cls.model.process_duation+duation).where( | |||
| chunk_num=cls.model.chunk_num + chunk_num, | |||
| process_duation=cls.model.process_duation + duation).where( | |||
| cls.model.id == doc_id).execute() | |||
| if num == 0:raise LookupError("Document not found which is supposed to be there") | |||
| num = Knowledgebase.update(token_num=Knowledgebase.token_num+token_num, chunk_num=Knowledgebase.chunk_num+chunk_num).where(Knowledgebase.id==kb_id).execute() | |||
| if num == 0: | |||
| raise LookupError( | |||
| "Document not found which is supposed to be there") | |||
| num = Knowledgebase.update( | |||
| token_num=Knowledgebase.token_num + token_num, | |||
| chunk_num=Knowledgebase.chunk_num + chunk_num).where( | |||
| Knowledgebase.id == kb_id).execute() | |||
| return num | |||
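Reviewer's note: `increment_chunk_num` runs two UPDATE statements back to back, so a failure between them leaves the document and knowledgebase counters out of sync. A hedged sketch wrapping both in the `DB.atomic()` transaction context this codebase already uses in `delete_user`:

```python
# Hypothetical sketch -- not part of this diff.
with DB.atomic():  # both counter updates commit or roll back together
    num = cls.model.update(
        token_num=cls.model.token_num + token_num,
        chunk_num=cls.model.chunk_num + chunk_num,
        process_duation=cls.model.process_duation + duation).where(
        cls.model.id == doc_id).execute()
    if num == 0:
        raise LookupError(
            "Document not found although it is supposed to exist")
    Knowledgebase.update(
        token_num=Knowledgebase.token_num + token_num,
        chunk_num=Knowledgebase.chunk_num + chunk_num).where(
        Knowledgebase.id == kb_id).execute()
```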
| @classmethod | |||
| @DB.connection_context() | |||
| def get_tenant_id(cls, doc_id): | |||
| docs = cls.model.select(Knowledgebase.tenant_id).join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)).where(cls.model.id == doc_id, Knowledgebase.status==StatusEnum.VALID.value) | |||
| docs = cls.model.select( | |||
| Knowledgebase.tenant_id).join( | |||
| Knowledgebase, on=( | |||
| Knowledgebase.id == cls.model.kb_id)).where( | |||
| cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value) | |||
| docs = docs.dicts() | |||
| if not docs:return | |||
| if not docs: | |||
| return | |||
| return docs[0]["tenant_id"] | |||
| @classmethod | |||
| @DB.connection_context() | |||
| def get_thumbnails(cls, docids): | |||
| fields = [cls.model.id, cls.model.thumbnail] | |||
| return list(cls.model.select(*fields).where(cls.model.id.in_(docids)).dicts()) | |||
| return list(cls.model.select( | |||
| *fields).where(cls.model.id.in_(docids)).dicts()) | |||
| @classmethod | |||
| @DB.connection_context() | |||
| def update_parser_config(cls, id, config): | |||
| e, d = cls.get_by_id(id) | |||
| if not e:raise LookupError(f"Document({id}) not found.") | |||
| if not e: | |||
| raise LookupError(f"Document({id}) not found.") | |||
| def dfs_update(old, new): | |||
| for k,v in new.items(): | |||
| for k, v in new.items(): | |||
| if k not in old: | |||
| old[k] = v | |||
| continue | |||
| if isinstance(v, dict): | |||
| assert isinstance(old[k], dict) | |||
| dfs_update(old[k], v) | |||
| else: old[k] = v | |||
| else: | |||
| old[k] = v | |||
| dfs_update(d.parser_config, config) | |||
| cls.update_by_id(id, {"parser_config": d.parser_config}) | |||
| cls.update_by_id(id, {"parser_config": d.parser_config}) | |||
| @classmethod | |||
| @DB.connection_context() | |||
| def get_doc_count(cls, tenant_id): | |||
| docs = cls.model.select(cls.model.id).join(Knowledgebase, | |||
| on=(Knowledgebase.id == cls.model.kb_id)).where( | |||
| Knowledgebase.tenant_id == tenant_id) | |||
| return len(docs) | |||
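Reviewer's note: `get_doc_count` fetches id rows and takes `len(...)`, materializing every row just to count them. A hedged alternative that pushes the count into SQL, using the same peewee query-builder API seen throughout this file:

```python
# Hypothetical sketch -- not part of this diff.
return (cls.model.select()
        .join(Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id))
        .where(Knowledgebase.tenant_id == tenant_id)
        .count())  # emits SELECT COUNT(...) instead of fetching rows
```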
| @@ -55,7 +55,7 @@ class KnowledgebaseService(CommonService): | |||
| cls.model.chunk_num, | |||
| cls.model.parser_id, | |||
| cls.model.parser_config] | |||
| kbs = cls.model.select(*fields).join(Tenant, on=((Tenant.id == cls.model.tenant_id)&(Tenant.status== StatusEnum.VALID.value))).where( | |||
| kbs = cls.model.select(*fields).join(Tenant, on=((Tenant.id == cls.model.tenant_id) & (Tenant.status == StatusEnum.VALID.value))).where( | |||
| (cls.model.id == kb_id), | |||
| (cls.model.status == StatusEnum.VALID.value) | |||
| ) | |||
| @@ -69,9 +69,11 @@ class KnowledgebaseService(CommonService): | |||
| @DB.connection_context() | |||
| def update_parser_config(cls, id, config): | |||
| e, m = cls.get_by_id(id) | |||
| if not e:raise LookupError(f"knowledgebase({id}) not found.") | |||
| if not e: | |||
| raise LookupError(f"knowledgebase({id}) not found.") | |||
| def dfs_update(old, new): | |||
| for k,v in new.items(): | |||
| for k, v in new.items(): | |||
| if k not in old: | |||
| old[k] = v | |||
| continue | |||
| @@ -80,12 +82,12 @@ class KnowledgebaseService(CommonService): | |||
| dfs_update(old[k], v) | |||
| elif isinstance(v, list): | |||
| assert isinstance(old[k], list) | |||
| old[k] = list(set(old[k]+v)) | |||
| else: old[k] = v | |||
| old[k] = list(set(old[k] + v)) | |||
| else: | |||
| old[k] = v | |||
| dfs_update(m.parser_config, config) | |||
| cls.update_by_id(id, {"parser_config": m.parser_config}) | |||
| @classmethod | |||
| @DB.connection_context() | |||
| def get_field_map(cls, ids): | |||
| @@ -94,4 +96,3 @@ class KnowledgebaseService(CommonService): | |||
| if k.parser_config and "field_map" in k.parser_config: | |||
| conf.update(k.parser_config["field_map"]) | |||
| return conf | |||
| @@ -59,7 +59,8 @@ class TenantLLMService(CommonService): | |||
| @classmethod | |||
| @DB.connection_context() | |||
| def model_instance(cls, tenant_id, llm_type, llm_name=None, lang="Chinese"): | |||
| def model_instance(cls, tenant_id, llm_type, | |||
| llm_name=None, lang="Chinese"): | |||
| e, tenant = TenantService.get_by_id(tenant_id) | |||
| if not e: | |||
| raise LookupError("Tenant not found") | |||
| @@ -126,29 +127,39 @@ class LLMBundle(object): | |||
| self.tenant_id = tenant_id | |||
| self.llm_type = llm_type | |||
| self.llm_name = llm_name | |||
| self.mdl = TenantLLMService.model_instance(tenant_id, llm_type, llm_name, lang=lang) | |||
| assert self.mdl, "Can't find mole for {}/{}/{}".format(tenant_id, llm_type, llm_name) | |||
| self.mdl = TenantLLMService.model_instance( | |||
| tenant_id, llm_type, llm_name, lang=lang) | |||
| assert self.mdl, "Can't find mole for {}/{}/{}".format( | |||
| tenant_id, llm_type, llm_name) | |||
| def encode(self, texts: list, batch_size=32): | |||
| emd, used_tokens = self.mdl.encode(texts, batch_size) | |||
| if TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens): | |||
| database_logger.error("Can't update token usage for {}/EMBEDDING".format(self.tenant_id)) | |||
| if not TenantLLMService.increase_usage( | |||
| self.tenant_id, self.llm_type, used_tokens): | |||
| database_logger.error( | |||
| "Can't update token usage for {}/EMBEDDING".format(self.tenant_id)) | |||
| return emd, used_tokens | |||
| def encode_queries(self, query: str): | |||
| emd, used_tokens = self.mdl.encode_queries(query) | |||
| if TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens): | |||
| database_logger.error("Can't update token usage for {}/EMBEDDING".format(self.tenant_id)) | |||
| if not TenantLLMService.increase_usage( | |||
| self.tenant_id, self.llm_type, used_tokens): | |||
| database_logger.error( | |||
| "Can't update token usage for {}/EMBEDDING".format(self.tenant_id)) | |||
| return emd, used_tokens | |||
| def describe(self, image, max_tokens=300): | |||
| txt, used_tokens = self.mdl.describe(image, max_tokens) | |||
| if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens): | |||
| database_logger.error("Can't update token usage for {}/IMAGE2TEXT".format(self.tenant_id)) | |||
| if not TenantLLMService.increase_usage( | |||
| self.tenant_id, self.llm_type, used_tokens): | |||
| database_logger.error( | |||
| "Can't update token usage for {}/IMAGE2TEXT".format(self.tenant_id)) | |||
| return txt | |||
| def chat(self, system, history, gen_conf): | |||
| txt, used_tokens = self.mdl.chat(system, history, gen_conf) | |||
| if TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name): | |||
| database_logger.error("Can't update token usage for {}/CHAT".format(self.tenant_id)) | |||
| if not TenantLLMService.increase_usage( | |||
| self.tenant_id, self.llm_type, used_tokens, self.llm_name): | |||
| database_logger.error( | |||
| "Can't update token usage for {}/CHAT".format(self.tenant_id)) | |||
| return txt | |||
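For context, a minimal usage sketch of `LLMBundle`, mirroring the `init_superuser` call earlier in this diff; the tenant id and generation settings are placeholders:

```python
# Hypothetical usage -- values invented for illustration.
chat_mdl = LLMBundle("tenant-123", LLMType.CHAT, llm_name=None, lang="Chinese")
answer = chat_mdl.chat(
    system="",
    history=[{"role": "user", "content": "Hello!"}],
    gen_conf={"temperature": 0.1, "max_tokens": 215},
)  # token usage is recorded against the tenant as a side effect
```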
| @@ -54,7 +54,8 @@ class UserService(CommonService): | |||
| if "id" not in kwargs: | |||
| kwargs["id"] = get_uuid() | |||
| if "password" in kwargs: | |||
| kwargs["password"] = generate_password_hash(str(kwargs["password"])) | |||
| kwargs["password"] = generate_password_hash( | |||
| str(kwargs["password"])) | |||
| kwargs["create_time"] = current_timestamp() | |||
| kwargs["create_date"] = datetime_format(datetime.now()) | |||
| @@ -63,12 +64,12 @@ class UserService(CommonService): | |||
| obj = cls.model(**kwargs).save(force_insert=True) | |||
| return obj | |||
| @classmethod | |||
| @DB.connection_context() | |||
| def delete_user(cls, user_ids, update_user_dict): | |||
| with DB.atomic(): | |||
| cls.model.update({"status": 0}).where(cls.model.id.in_(user_ids)).execute() | |||
| cls.model.update({"status": 0}).where( | |||
| cls.model.id.in_(user_ids)).execute() | |||
| @classmethod | |||
| @DB.connection_context() | |||
| @@ -77,7 +78,8 @@ class UserService(CommonService): | |||
| if user_dict: | |||
| user_dict["update_time"] = current_timestamp() | |||
| user_dict["update_date"] = datetime_format(datetime.now()) | |||
| cls.model.update(user_dict).where(cls.model.id == user_id).execute() | |||
| cls.model.update(user_dict).where( | |||
| cls.model.id == user_id).execute() | |||
| class TenantService(CommonService): | |||
| @@ -86,25 +88,42 @@ class TenantService(CommonService): | |||
| @classmethod | |||
| @DB.connection_context() | |||
| def get_by_user_id(cls, user_id): | |||
| fields = [cls.model.id.alias("tenant_id"), cls.model.name, cls.model.llm_id, cls.model.embd_id, cls.model.asr_id, cls.model.img2txt_id, cls.model.parser_ids, UserTenant.role] | |||
| return list(cls.model.select(*fields)\ | |||
| .join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id==user_id) & (UserTenant.status == StatusEnum.VALID.value)))\ | |||
| .where(cls.model.status == StatusEnum.VALID.value).dicts()) | |||
| fields = [ | |||
| cls.model.id.alias("tenant_id"), | |||
| cls.model.name, | |||
| cls.model.llm_id, | |||
| cls.model.embd_id, | |||
| cls.model.asr_id, | |||
| cls.model.img2txt_id, | |||
| cls.model.parser_ids, | |||
| UserTenant.role] | |||
| return list(cls.model.select(*fields) | |||
| .join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id == user_id) & (UserTenant.status == StatusEnum.VALID.value))) | |||
| .where(cls.model.status == StatusEnum.VALID.value).dicts()) | |||
| @classmethod | |||
| @DB.connection_context() | |||
| def get_joined_tenants_by_user_id(cls, user_id): | |||
| fields = [cls.model.id.alias("tenant_id"), cls.model.name, cls.model.llm_id, cls.model.embd_id, cls.model.asr_id, cls.model.img2txt_id, UserTenant.role] | |||
| return list(cls.model.select(*fields)\ | |||
| .join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id==user_id) & (UserTenant.status == StatusEnum.VALID.value) & (UserTenant.role==UserTenantRole.NORMAL.value)))\ | |||
| .where(cls.model.status == StatusEnum.VALID.value).dicts()) | |||
| fields = [ | |||
| cls.model.id.alias("tenant_id"), | |||
| cls.model.name, | |||
| cls.model.llm_id, | |||
| cls.model.embd_id, | |||
| cls.model.asr_id, | |||
| cls.model.img2txt_id, | |||
| UserTenant.role] | |||
| return list(cls.model.select(*fields) | |||
| .join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id == user_id) & (UserTenant.status == StatusEnum.VALID.value) & (UserTenant.role == UserTenantRole.NORMAL.value))) | |||
| .where(cls.model.status == StatusEnum.VALID.value).dicts()) | |||
| @classmethod | |||
| @DB.connection_context() | |||
| def decrease(cls, user_id, num): | |||
| num = cls.model.update(credit=cls.model.credit - num).where( | |||
| cls.model.id == user_id).execute() | |||
| if num == 0: raise LookupError("Tenant not found which is supposed to be there") | |||
| if num == 0: | |||
| raise LookupError("Tenant not found which is supposed to be there") | |||
| class UserTenantService(CommonService): | |||
| model = UserTenant | |||
| @@ -13,16 +13,22 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| from rag.utils import ELASTICSEARCH | |||
| from rag.nlp import search | |||
| import os | |||
| from enum import IntEnum, Enum | |||
| from api.utils import get_base_config,decrypt_database_config | |||
| from api.utils import get_base_config, decrypt_database_config | |||
| from api.utils.file_utils import get_project_base_directory | |||
| from api.utils.log_utils import LoggerFactory, getLogger | |||
| # Logger | |||
| LoggerFactory.set_directory(os.path.join(get_project_base_directory(), "logs", "api")) | |||
| LoggerFactory.set_directory( | |||
| os.path.join( | |||
| get_project_base_directory(), | |||
| "logs", | |||
| "api")) | |||
| # {CRITICAL: 50, FATAL:50, ERROR:40, WARNING:30, WARN:30, INFO:20, DEBUG:10, NOTSET:0} | |||
| LoggerFactory.LEVEL = 10 | |||
| @@ -86,7 +92,9 @@ default_llm = { | |||
| LLM = get_base_config("user_default_llm", {}) | |||
| LLM_FACTORY = LLM.get("factory", "Tongyi-Qianwen") | |||
| if LLM_FACTORY not in default_llm: | |||
| print("\33[91m【ERROR】\33[0m:", f"LLM factory {LLM_FACTORY} has not supported yet, switch to 'Tongyi-Qianwen/QWen' automatically, and please check the API_KEY in service_conf.yaml.") | |||
| print( | |||
| "\33[91m【ERROR】\33[0m:", | |||
| f"LLM factory {LLM_FACTORY} has not supported yet, switch to 'Tongyi-Qianwen/QWen' automatically, and please check the API_KEY in service_conf.yaml.") | |||
| LLM_FACTORY = "Tongyi-Qianwen" | |||
| CHAT_MDL = default_llm[LLM_FACTORY]["chat_model"] | |||
| EMBEDDING_MDL = default_llm[LLM_FACTORY]["embedding_model"] | |||
| @@ -94,7 +102,9 @@ ASR_MDL = default_llm[LLM_FACTORY]["asr_model"] | |||
| IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"] | |||
| API_KEY = LLM.get("api_key", "") | |||
| PARSERS = LLM.get("parsers", "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One") | |||
| PARSERS = LLM.get( | |||
| "parsers", | |||
| "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One") | |||
| # distribution | |||
| DEPENDENT_DISTRIBUTION = get_base_config("dependent_distribution", False) | |||
| @@ -103,13 +113,25 @@ RAG_FLOW_UPDATE_CHECK = False | |||
| HOST = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("host", "127.0.0.1") | |||
| HTTP_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("http_port") | |||
| SECRET_KEY = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("secret_key", "infiniflow") | |||
| TOKEN_EXPIRE_IN = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("token_expires_in", 3600) | |||
| NGINX_HOST = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("nginx", {}).get("host") or HOST | |||
| NGINX_HTTP_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("nginx", {}).get("http_port") or HTTP_PORT | |||
| RANDOM_INSTANCE_ID = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("random_instance_id", False) | |||
| SECRET_KEY = get_base_config( | |||
| RAG_FLOW_SERVICE_NAME, | |||
| {}).get( | |||
| "secret_key", | |||
| "infiniflow") | |||
| TOKEN_EXPIRE_IN = get_base_config( | |||
| RAG_FLOW_SERVICE_NAME, {}).get( | |||
| "token_expires_in", 3600) | |||
| NGINX_HOST = get_base_config( | |||
| RAG_FLOW_SERVICE_NAME, {}).get( | |||
| "nginx", {}).get("host") or HOST | |||
| NGINX_HTTP_PORT = get_base_config( | |||
| RAG_FLOW_SERVICE_NAME, {}).get( | |||
| "nginx", {}).get("http_port") or HTTP_PORT | |||
| RANDOM_INSTANCE_ID = get_base_config( | |||
| RAG_FLOW_SERVICE_NAME, {}).get( | |||
| "random_instance_id", False) | |||
| PROXY = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("proxy") | |||
| PROXY_PROTOCOL = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("protocol") | |||
| @@ -124,7 +146,9 @@ UPLOAD_DATA_FROM_CLIENT = True | |||
| AUTHENTICATION_CONF = get_base_config("authentication", {}) | |||
| # client | |||
| CLIENT_AUTHENTICATION = AUTHENTICATION_CONF.get("client", {}).get("switch", False) | |||
| CLIENT_AUTHENTICATION = AUTHENTICATION_CONF.get( | |||
| "client", {}).get( | |||
| "switch", False) | |||
| HTTP_APP_KEY = AUTHENTICATION_CONF.get("client", {}).get("http_app_key") | |||
| GITHUB_OAUTH = get_base_config("oauth", {}).get("github") | |||
| WECHAT_OAUTH = get_base_config("oauth", {}).get("wechat") | |||
| @@ -147,12 +171,10 @@ USE_AUTHENTICATION = False | |||
| USE_DATA_AUTHENTICATION = False | |||
| AUTOMATIC_AUTHORIZATION_OUTPUT_DATA = True | |||
| USE_DEFAULT_TIMEOUT = False | |||
| AUTHENTICATION_DEFAULT_TIMEOUT = 7 * 24 * 60 * 60 # s | |||
| AUTHENTICATION_DEFAULT_TIMEOUT = 7 * 24 * 60 * 60 # s | |||
| PRIVILEGE_COMMAND_WHITELIST = [] | |||
| CHECK_NODES_IDENTITY = False | |||
| from rag.nlp import search | |||
| from rag.utils import ELASTICSEARCH | |||
| retrievaler = search.Dealer(ELASTICSEARCH) | |||
| @@ -162,7 +184,7 @@ class CustomEnum(Enum): | |||
| try: | |||
| cls(value) | |||
| return True | |||
| except: | |||
| except BaseException: | |||
| return False | |||
| @classmethod | |||
| @@ -34,10 +34,12 @@ from . import file_utils | |||
| SERVICE_CONF = "service_conf.yaml" | |||
| def conf_realpath(conf_name): | |||
| conf_path = f"conf/{conf_name}" | |||
| return os.path.join(file_utils.get_project_base_directory(), conf_path) | |||
| def get_base_config(key, default=None, conf_name=SERVICE_CONF) -> dict: | |||
| local_config = {} | |||
| local_path = conf_realpath(f'local.{conf_name}') | |||
| @@ -62,7 +64,8 @@ def get_base_config(key, default=None, conf_name=SERVICE_CONF) -> dict: | |||
| return config.get(key, default) if key is not None else config | |||
| use_deserialize_safe_module = get_base_config('use_deserialize_safe_module', False) | |||
| use_deserialize_safe_module = get_base_config( | |||
| 'use_deserialize_safe_module', False) | |||
| class CoordinationCommunicationProtocol(object): | |||
| @@ -93,7 +96,8 @@ class BaseType: | |||
| data[_k] = _dict(vv) | |||
| else: | |||
| data = obj | |||
| return {"type": obj.__class__.__name__, "data": data, "module": module} | |||
| return {"type": obj.__class__.__name__, | |||
| "data": data, "module": module} | |||
| return _dict(self) | |||
| @@ -129,7 +133,8 @@ def rag_uuid(): | |||
| def string_to_bytes(string): | |||
| return string if isinstance(string, bytes) else string.encode(encoding="utf-8") | |||
| return string if isinstance( | |||
| string, bytes) else string.encode(encoding="utf-8") | |||
| def bytes_to_string(byte): | |||
| @@ -137,7 +142,11 @@ def bytes_to_string(byte): | |||
| def json_dumps(src, byte=False, indent=None, with_type=False): | |||
| dest = json.dumps(src, indent=indent, cls=CustomJSONEncoder, with_type=with_type) | |||
| dest = json.dumps( | |||
| src, | |||
| indent=indent, | |||
| cls=CustomJSONEncoder, | |||
| with_type=with_type) | |||
| if byte: | |||
| dest = string_to_bytes(dest) | |||
| return dest | |||
| @@ -146,7 +155,8 @@ def json_dumps(src, byte=False, indent=None, with_type=False): | |||
| def json_loads(src, object_hook=None, object_pairs_hook=None): | |||
| if isinstance(src, bytes): | |||
| src = bytes_to_string(src) | |||
| return json.loads(src, object_hook=object_hook, object_pairs_hook=object_pairs_hook) | |||
| return json.loads(src, object_hook=object_hook, | |||
| object_pairs_hook=object_pairs_hook) | |||
| def current_timestamp(): | |||
| @@ -177,7 +187,9 @@ def serialize_b64(src, to_str=False): | |||
| def deserialize_b64(src): | |||
| src = base64.b64decode(string_to_bytes(src) if isinstance(src, str) else src) | |||
| src = base64.b64decode( | |||
| string_to_bytes(src) if isinstance(src, str) else src) | |||
| if use_deserialize_safe_module: | |||
| return restricted_loads(src) | |||
| return pickle.loads(src) | |||
| @@ -237,12 +249,14 @@ def get_lan_ip(): | |||
| pass | |||
| return ip or '' | |||
| def from_dict_hook(in_dict: dict): | |||
| if "type" in in_dict and "data" in in_dict: | |||
| if in_dict["module"] is None: | |||
| return in_dict["data"] | |||
| else: | |||
| return getattr(importlib.import_module(in_dict["module"]), in_dict["type"])(**in_dict["data"]) | |||
| return getattr(importlib.import_module( | |||
| in_dict["module"]), in_dict["type"])(**in_dict["data"]) | |||
| else: | |||
| return in_dict | |||
| @@ -259,12 +273,16 @@ def decrypt_database_password(password): | |||
| raise ValueError("No private key") | |||
| module_fun = encrypt_module.split("#") | |||
| pwdecrypt_fun = getattr(importlib.import_module(module_fun[0]), module_fun[1]) | |||
| pwdecrypt_fun = getattr( | |||
| importlib.import_module( | |||
| module_fun[0]), | |||
| module_fun[1]) | |||
| return pwdecrypt_fun(private_key, password) | |||
| def decrypt_database_config(database=None, passwd_key="password", name="database"): | |||
| def decrypt_database_config( | |||
| database=None, passwd_key="password", name="database"): | |||
| if not database: | |||
| database = get_base_config(name, {}) | |||
| @@ -275,7 +293,8 @@ def decrypt_database_config(database=None, passwd_key="password", name="database | |||
| def update_config(key, value, conf_name=SERVICE_CONF): | |||
| conf_path = conf_realpath(conf_name=conf_name) | |||
| if not os.path.isabs(conf_path): | |||
| conf_path = os.path.join(file_utils.get_project_base_directory(), conf_path) | |||
| conf_path = os.path.join( | |||
| file_utils.get_project_base_directory(), conf_path) | |||
| with FileLock(os.path.join(os.path.dirname(conf_path), ".lock")): | |||
| config = file_utils.load_yaml_conf(conf_path=conf_path) or {} | |||
| @@ -288,7 +307,8 @@ def get_uuid(): | |||
| def datetime_format(date_time: datetime.datetime) -> datetime.datetime: | |||
| return datetime.datetime(date_time.year, date_time.month, date_time.day, date_time.hour, date_time.minute, date_time.second) | |||
| return datetime.datetime(date_time.year, date_time.month, date_time.day, | |||
| date_time.hour, date_time.minute, date_time.second) | |||
| def get_format_time() -> datetime.datetime: | |||
| @@ -307,14 +327,19 @@ def elapsed2time(elapsed): | |||
| def decrypt(line): | |||
| file_path = os.path.join(file_utils.get_project_base_directory(), "conf", "private.pem") | |||
| file_path = os.path.join( | |||
| file_utils.get_project_base_directory(), | |||
| "conf", | |||
| "private.pem") | |||
| rsa_key = RSA.importKey(open(file_path).read(), "Welcome") | |||
| cipher = Cipher_pkcs1_v1_5.new(rsa_key) | |||
| return cipher.decrypt(base64.b64decode(line), "Fail to decrypt password!").decode('utf-8') | |||
| return cipher.decrypt(base64.b64decode( | |||
| line), "Fail to decrypt password!").decode('utf-8') | |||
| def download_img(url): | |||
| if not url: return "" | |||
| if not url: | |||
| return "" | |||
| response = requests.get(url) | |||
| return "data:" + \ | |||
| response.headers.get('Content-Type', 'image/jpg') + ";" + \ | |||
| @@ -19,7 +19,7 @@ import time | |||
| from functools import wraps | |||
| from io import BytesIO | |||
| from flask import ( | |||
| Response, jsonify, send_file,make_response, | |||
| Response, jsonify, send_file, make_response, | |||
| request as flask_request, | |||
| ) | |||
| from werkzeug.http import HTTP_STATUS_CODES | |||
| @@ -29,7 +29,7 @@ from api.versions import get_rag_version | |||
| from api.settings import RetCode | |||
| from api.settings import ( | |||
| REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC, | |||
| stat_logger,CLIENT_AUTHENTICATION, HTTP_APP_KEY, SECRET_KEY | |||
| stat_logger, CLIENT_AUTHENTICATION, HTTP_APP_KEY, SECRET_KEY | |||
| ) | |||
| import requests | |||
| import functools | |||
| @@ -40,14 +40,21 @@ from hmac import HMAC | |||
| from urllib.parse import quote, urlencode | |||
| requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder) | |||
| requests.models.complexjson.dumps = functools.partial( | |||
| json.dumps, cls=CustomJSONEncoder) | |||
| def request(**kwargs): | |||
| sess = requests.Session() | |||
| stream = kwargs.pop('stream', sess.stream) | |||
| timeout = kwargs.pop('timeout', None) | |||
| kwargs['headers'] = {k.replace('_', '-').upper(): v for k, v in kwargs.get('headers', {}).items()} | |||
| kwargs['headers'] = { | |||
| k.replace('_', '-').upper(): v | |||
| for k, v in kwargs.get('headers', {}).items()} | |||
| prepped = requests.Request(**kwargs).prepare() | |||
| if CLIENT_AUTHENTICATION and HTTP_APP_KEY and SECRET_KEY: | |||
| @@ -59,7 +66,11 @@ def request(**kwargs): | |||
| HTTP_APP_KEY.encode('ascii'), | |||
| prepped.path_url.encode('ascii'), | |||
| prepped.body if kwargs.get('json') else b'', | |||
| urlencode(sorted(kwargs['data'].items()), quote_via=quote, safe='-._~').encode('ascii') | |||
| urlencode( | |||
| sorted( | |||
| kwargs['data'].items()), | |||
| quote_via=quote, | |||
| safe='-._~').encode('ascii') | |||
| if kwargs.get('data') and isinstance(kwargs['data'], dict) else b'', | |||
| ]), 'sha1').digest()).decode('ascii') | |||
| @@ -88,11 +99,12 @@ def get_exponential_backoff_interval(retries, full_jitter=False): | |||
| return max(0, countdown) | |||
| def get_json_result(retcode=RetCode.SUCCESS, retmsg='success', data=None, job_id=None, meta=None): | |||
| def get_json_result(retcode=RetCode.SUCCESS, retmsg='success', | |||
| data=None, job_id=None, meta=None): | |||
| import re | |||
| result_dict = { | |||
| "retcode": retcode, | |||
| "retmsg":retmsg, | |||
| "retmsg": retmsg, | |||
| # "retmsg": re.sub(r"rag", "seceum", retmsg, flags=re.IGNORECASE), | |||
| "data": data, | |||
| "jobId": job_id, | |||
| @@ -107,9 +119,17 @@ def get_json_result(retcode=RetCode.SUCCESS, retmsg='success', data=None, job_id | |||
| response[key] = value | |||
| return jsonify(response) | |||
| def get_data_error_result(retcode=RetCode.DATA_ERROR, retmsg='Sorry! Data missing!'): | |||
| def get_data_error_result(retcode=RetCode.DATA_ERROR, | |||
| retmsg='Sorry! Data missing!'): | |||
| import re | |||
| result_dict = {"retcode": retcode, "retmsg": re.sub(r"rag", "seceum", retmsg, flags=re.IGNORECASE)} | |||
| result_dict = { | |||
| "retcode": retcode, | |||
| "retmsg": re.sub(r"rag", "seceum", retmsg, flags=re.IGNORECASE)} | |||
| response = {} | |||
| for key, value in result_dict.items(): | |||
| if value is None and key != "retcode": | |||
| @@ -118,15 +138,17 @@ def get_data_error_result(retcode=RetCode.DATA_ERROR, retmsg='Sorry! Data missin | |||
| response[key] = value | |||
| return jsonify(response) | |||
| def server_error_response(e): | |||
| stat_logger.exception(e) | |||
| try: | |||
| if e.code==401: | |||
| if e.code == 401: | |||
| return get_json_result(retcode=401, retmsg=repr(e)) | |||
| except: | |||
| except BaseException: | |||
| pass | |||
| if len(e.args) > 1: | |||
| return get_json_result(retcode=RetCode.EXCEPTION_ERROR, retmsg=repr(e.args[0]), data=e.args[1]) | |||
| return get_json_result( | |||
| retcode=RetCode.EXCEPTION_ERROR, retmsg=repr(e.args[0]), data=e.args[1]) | |||
| return get_json_result(retcode=RetCode.EXCEPTION_ERROR, retmsg=repr(e)) | |||
| @@ -162,10 +184,13 @@ def validate_request(*args, **kwargs): | |||
| if no_arguments or error_arguments: | |||
| error_string = "" | |||
| if no_arguments: | |||
| error_string += "required argument are missing: {}; ".format(",".join(no_arguments)) | |||
| error_string += "required argument are missing: {}; ".format( | |||
| ",".join(no_arguments)) | |||
| if error_arguments: | |||
| error_string += "required argument values: {}".format(",".join(["{}={}".format(a[0], a[1]) for a in error_arguments])) | |||
| return get_json_result(retcode=RetCode.ARGUMENT_ERROR, retmsg=error_string) | |||
| error_string += "required argument values: {}".format( | |||
| ",".join(["{}={}".format(a[0], a[1]) for a in error_arguments])) | |||
| return get_json_result( | |||
| retcode=RetCode.ARGUMENT_ERROR, retmsg=error_string) | |||
| return func(*_args, **_kwargs) | |||
| return decorated_function | |||
| return wrapper | |||
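For context, a hedged sketch of how `validate_request` is typically stacked onto a Flask route; the blueprint, route path, and field names are invented:

```python
# Hypothetical route -- names invented for illustration.
@manager.route('/example/set', methods=['POST'])
@validate_request("doc_id", "content")
def example_set():
    req = flask_request.json
    # Reaching this point means both fields were present and non-empty;
    # otherwise the decorator already returned an ARGUMENT_ERROR payload.
    return get_json_result(data={"doc_id": req["doc_id"]})
```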
| @@ -193,7 +218,8 @@ def get_json_result(retcode=RetCode.SUCCESS, retmsg='success', data=None): | |||
| return jsonify(response) | |||
| def cors_reponse(retcode=RetCode.SUCCESS, retmsg='success', data=None, auth=None): | |||
| def cors_reponse(retcode=RetCode.SUCCESS, | |||
| retmsg='success', data=None, auth=None): | |||
| result_dict = {"retcode": retcode, "retmsg": retmsg, "data": data} | |||
| response_dict = {} | |||
| for key, value in result_dict.items(): | |||
| @@ -209,4 +235,4 @@ def cors_reponse(retcode=RetCode.SUCCESS, retmsg='success', data=None, auth=None | |||
| response.headers["Access-Control-Allow-Headers"] = "*" | |||
| response.headers["Access-Control-Allow-Headers"] = "*" | |||
| response.headers["Access-Control-Expose-Headers"] = "Authorization" | |||
| return response | |||
| return response | |||
| @@ -29,6 +29,7 @@ from api.db import FileType | |||
| PROJECT_BASE = os.getenv("RAG_PROJECT_BASE") or os.getenv("RAG_DEPLOY_BASE") | |||
| RAG_BASE = os.getenv("RAG_BASE") | |||
| def get_project_base_directory(*args): | |||
| global PROJECT_BASE | |||
| if PROJECT_BASE is None: | |||
| @@ -65,7 +66,6 @@ def get_rag_python_directory(*args): | |||
| return get_rag_directory("python", *args) | |||
| @cached(cache=LRUCache(maxsize=10)) | |||
| def load_json_conf(conf_path): | |||
| if os.path.isabs(conf_path): | |||
| @@ -146,10 +146,12 @@ def filename_type(filename): | |||
| if re.match(r".*\.pdf$", filename): | |||
| return FileType.PDF.value | |||
| if re.match(r".*\.(docx|doc|ppt|pptx|yml|xml|htm|json|csv|txt|ini|xls|xlsx|wps|rtf|hlp|pages|numbers|key|md)$", filename): | |||
| if re.match( | |||
| r".*\.(docx|doc|ppt|pptx|yml|xml|htm|json|csv|txt|ini|xls|xlsx|wps|rtf|hlp|pages|numbers|key|md)$", filename): | |||
| return FileType.DOC.value | |||
| if re.match(r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$", filename): | |||
| if re.match( | |||
| r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$", filename): | |||
| return FileType.AURAL.value | |||
| if re.match(r".*\.(jpg|jpeg|png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng|icon|ico|mpg|mpeg|avi|rm|rmvb|mov|wmv|asf|dat|asx|wvx|mpe|mpa|mp4)$", filename): | |||
| @@ -164,14 +166,16 @@ def thumbnail(filename, blob): | |||
| buffered = BytesIO() | |||
| Image.frombytes("RGB", [pix.width, pix.height], | |||
| pix.samples).save(buffered, format="png") | |||
| return "data:image/png;base64," + base64.b64encode(buffered.getvalue()).decode("utf-8") | |||
| return "data:image/png;base64," + \ | |||
| base64.b64encode(buffered.getvalue()).decode("utf-8") | |||
| if re.match(r".*\.(jpg|jpeg|png|tif|gif|icon|ico|webp)$", filename): | |||
| image = Image.open(BytesIO(blob)) | |||
| image.thumbnail((30, 30)) | |||
| buffered = BytesIO() | |||
| image.save(buffered, format="png") | |||
| return "data:image/png;base64," + base64.b64encode(buffered.getvalue()).decode("utf-8") | |||
| return "data:image/png;base64," + \ | |||
| base64.b64encode(buffered.getvalue()).decode("utf-8") | |||
| if re.match(r".*\.(ppt|pptx)$", filename): | |||
| import aspose.slides as slides | |||
| @@ -179,8 +183,10 @@ def thumbnail(filename, blob): | |||
| try: | |||
| with slides.Presentation(BytesIO(blob)) as presentation: | |||
| buffered = BytesIO() | |||
| presentation.slides[0].get_thumbnail(0.03, 0.03).save(buffered, drawing.imaging.ImageFormat.png) | |||
| return "data:image/png;base64," + base64.b64encode(buffered.getvalue()).decode("utf-8") | |||
| presentation.slides[0].get_thumbnail(0.03, 0.03).save( | |||
| buffered, drawing.imaging.ImageFormat.png) | |||
| return "data:image/png;base64," + \ | |||
| base64.b64encode(buffered.getvalue()).decode("utf-8") | |||
| except Exception as e: | |||
| pass | |||
| @@ -190,6 +196,3 @@ def traversal_files(base): | |||
| for f in fs: | |||
| fullname = os.path.join(root, f) | |||
| yield fullname | |||
| @@ -23,6 +23,7 @@ from threading import RLock | |||
| from api.utils import file_utils | |||
| class LoggerFactory(object): | |||
| TYPE = "FILE" | |||
| LOG_FORMAT = "[%(levelname)s] [%(asctime)s] [jobId] [%(process)s:%(thread)s] - [%(module)s.%(funcName)s] [line:%(lineno)d]: %(message)s" | |||
| @@ -49,7 +50,8 @@ class LoggerFactory(object): | |||
| schedule_logger_dict = {} | |||
| @staticmethod | |||
| def set_directory(directory=None, parent_log_dir=None, append_to_parent_log=None, force=False): | |||
| def set_directory(directory=None, parent_log_dir=None, | |||
| append_to_parent_log=None, force=False): | |||
| if parent_log_dir: | |||
| LoggerFactory.PARENT_LOG_DIR = parent_log_dir | |||
| if append_to_parent_log: | |||
| @@ -66,11 +68,13 @@ class LoggerFactory(object): | |||
| else: | |||
| os.makedirs(LoggerFactory.LOG_DIR, exist_ok=True) | |||
| for loggerName, ghandler in LoggerFactory.global_handler_dict.items(): | |||
| for className, (logger, handler) in LoggerFactory.logger_dict.items(): | |||
| for className, (logger, | |||
| handler) in LoggerFactory.logger_dict.items(): | |||
| logger.removeHandler(ghandler) | |||
| ghandler.close() | |||
| LoggerFactory.global_handler_dict = {} | |||
| for className, (logger, handler) in LoggerFactory.logger_dict.items(): | |||
| for className, (logger, | |||
| handler) in LoggerFactory.logger_dict.items(): | |||
| logger.removeHandler(handler) | |||
| _handler = None | |||
| if handler: | |||
| @@ -111,19 +115,23 @@ class LoggerFactory(object): | |||
| if logger_name_key not in LoggerFactory.global_handler_dict: | |||
| with LoggerFactory.lock: | |||
| if logger_name_key not in LoggerFactory.global_handler_dict: | |||
| handler = LoggerFactory.get_handler(logger_name, level, log_dir) | |||
| handler = LoggerFactory.get_handler( | |||
| logger_name, level, log_dir) | |||
| LoggerFactory.global_handler_dict[logger_name_key] = handler | |||
| return LoggerFactory.global_handler_dict[logger_name_key] | |||
| @staticmethod | |||
| def get_handler(class_name, level=None, log_dir=None, log_type=None, job_id=None): | |||
| def get_handler(class_name, level=None, log_dir=None, | |||
| log_type=None, job_id=None): | |||
| if not log_type: | |||
| if not LoggerFactory.LOG_DIR or not class_name: | |||
| return logging.StreamHandler() | |||
| # return Diy_StreamHandler() | |||
| if not log_dir: | |||
| log_file = os.path.join(LoggerFactory.LOG_DIR, "{}.log".format(class_name)) | |||
| log_file = os.path.join( | |||
| LoggerFactory.LOG_DIR, | |||
| "{}.log".format(class_name)) | |||
| else: | |||
| log_file = os.path.join(log_dir, "{}.log".format(class_name)) | |||
| else: | |||
| @@ -133,16 +141,16 @@ class LoggerFactory(object): | |||
| os.makedirs(os.path.dirname(log_file), exist_ok=True) | |||
| if LoggerFactory.log_share: | |||
| handler = ROpenHandler(log_file, | |||
| when='D', | |||
| interval=1, | |||
| backupCount=14, | |||
| delay=True) | |||
| when='D', | |||
| interval=1, | |||
| backupCount=14, | |||
| delay=True) | |||
| else: | |||
| handler = TimedRotatingFileHandler(log_file, | |||
| when='D', | |||
| interval=1, | |||
| backupCount=14, | |||
| delay=True) | |||
| when='D', | |||
| interval=1, | |||
| backupCount=14, | |||
| delay=True) | |||
| if level: | |||
| handler.level = level | |||
| @@ -170,7 +178,9 @@ class LoggerFactory(object): | |||
| for level in LoggerFactory.levels: | |||
| if level >= LoggerFactory.LEVEL: | |||
| level_logger_name = logging._levelToName[level] | |||
| logger.addHandler(LoggerFactory.get_global_handler(level_logger_name, level)) | |||
| logger.addHandler( | |||
| LoggerFactory.get_global_handler( | |||
| level_logger_name, level)) | |||
| if LoggerFactory.append_to_parent_log and LoggerFactory.PARENT_LOG_DIR: | |||
| for level in LoggerFactory.levels: | |||
| if level >= LoggerFactory.LEVEL: | |||
| @@ -224,22 +234,26 @@ def start_log(msg, job=None, task=None, role=None, party_id=None, detail=None): | |||
| return f"{prefix}start to {msg}{suffix}" | |||
| def successful_log(msg, job=None, task=None, role=None, party_id=None, detail=None): | |||
| def successful_log(msg, job=None, task=None, role=None, | |||
| party_id=None, detail=None): | |||
| prefix, suffix = base_msg(job, task, role, party_id, detail) | |||
| return f"{prefix}{msg} successfully{suffix}" | |||
| def warning_log(msg, job=None, task=None, role=None, party_id=None, detail=None): | |||
| def warning_log(msg, job=None, task=None, role=None, | |||
| party_id=None, detail=None): | |||
| prefix, suffix = base_msg(job, task, role, party_id, detail) | |||
| return f"{prefix}{msg} is not effective{suffix}" | |||
| def failed_log(msg, job=None, task=None, role=None, party_id=None, detail=None): | |||
| def failed_log(msg, job=None, task=None, role=None, | |||
| party_id=None, detail=None): | |||
| prefix, suffix = base_msg(job, task, role, party_id, detail) | |||
| return f"{prefix}failed to {msg}{suffix}" | |||
| def base_msg(job=None, task=None, role: str = None, party_id: typing.Union[str, int] = None, detail=None): | |||
| def base_msg(job=None, task=None, role: str = None, | |||
| party_id: typing.Union[str, int] = None, detail=None): | |||
| if detail: | |||
| detail_msg = f" detail: \n{detail}" | |||
| else: | |||
| @@ -285,10 +299,14 @@ def get_job_logger(job_id, log_type): | |||
| for job_log_dir in log_dirs: | |||
| handler = LoggerFactory.get_handler(class_name=None, level=LoggerFactory.LEVEL, | |||
| log_dir=job_log_dir, log_type=log_type, job_id=job_id) | |||
| error_handler = LoggerFactory.get_handler(class_name=None, level=logging.ERROR, log_dir=job_log_dir, log_type=log_type, job_id=job_id) | |||
| error_handler = LoggerFactory.get_handler( | |||
| class_name=None, | |||
| level=logging.ERROR, | |||
| log_dir=job_log_dir, | |||
| log_type=log_type, | |||
| job_id=job_id) | |||
| logger.addHandler(handler) | |||
| logger.addHandler(error_handler) | |||
| with LoggerFactory.lock: | |||
| LoggerFactory.schedule_logger_dict[job_id + log_type] = logger | |||
| return logger | |||
| @@ -1,18 +1,23 @@ | |||
| import base64, os, sys | |||
| import base64 | |||
| import os | |||
| import sys | |||
| from Cryptodome.PublicKey import RSA | |||
| from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5 | |||
| from api.utils import decrypt, file_utils | |||
| def crypt(line): | |||
| file_path = os.path.join(file_utils.get_project_base_directory(), "conf", "public.pem") | |||
| file_path = os.path.join( | |||
| file_utils.get_project_base_directory(), | |||
| "conf", | |||
| "public.pem") | |||
| rsa_key = RSA.importKey(open(file_path).read()) | |||
| cipher = Cipher_pkcs1_v1_5.new(rsa_key) | |||
| return base64.b64encode(cipher.encrypt(line.encode('utf-8'))).decode("utf-8") | |||
| return base64.b64encode(cipher.encrypt( | |||
| line.encode('utf-8'))).decode("utf-8") | |||
| if __name__ == "__main__": | |||
| pswd = crypt(sys.argv[1]) | |||
| print(pswd) | |||
| print(decrypt(pswd)) | |||
| @@ -4,5 +4,3 @@ from .pdf_parser import HuParser as PdfParser, PlainParser | |||
| from .docx_parser import HuDocxParser as DocxParser | |||
| from .excel_parser import HuExcelParser as ExcelParser | |||
| from .ppt_parser import HuPptParser as PptParser | |||
| @@ -99,12 +99,15 @@ class HuDocxParser: | |||
| return ["\n".join(lines)] | |||
| def __call__(self, fnm, from_page=0, to_page=100000): | |||
| self.doc = Document(fnm) if isinstance(fnm, str) else Document(BytesIO(fnm)) | |||
| self.doc = Document(fnm) if isinstance( | |||
| fnm, str) else Document(BytesIO(fnm)) | |||
| pn = 0 | |||
| secs = [] | |||
| for p in self.doc.paragraphs: | |||
| if pn > to_page: break | |||
| if from_page <= pn < to_page and p.text.strip(): secs.append((p.text, p.style.name)) | |||
| if pn > to_page: | |||
| break | |||
| if from_page <= pn < to_page and p.text.strip(): | |||
| secs.append((p.text, p.style.name)) | |||
| for run in p.runs: | |||
| if 'lastRenderedPageBreak' in run._element.xml: | |||
| pn += 1 | |||
| @@ -15,13 +15,16 @@ class HuExcelParser: | |||
| ws = wb[sheetname] | |||
| rows = list(ws.rows) | |||
| tb += f"<table><caption>{sheetname}</caption><tr>" | |||
| for t in list(rows[0]): tb += f"<th>{t.value}</th>" | |||
| for t in list(rows[0]): | |||
| tb += f"<th>{t.value}</th>" | |||
| tb += "</tr>" | |||
| for r in list(rows[1:]): | |||
| tb += "<tr>" | |||
| for i,c in enumerate(r): | |||
| if c.value is None: tb += "<td></td>" | |||
| else: tb += f"<td>{c.value}</td>" | |||
| for i, c in enumerate(r): | |||
| if c.value is None: | |||
| tb += "<td></td>" | |||
| else: | |||
| tb += f"<td>{c.value}</td>" | |||
| tb += "</tr>" | |||
| tb += "</table>\n" | |||
| return tb | |||
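| # The loop above serialises each sheet as one HTML table: the first row becomes | |||
| # <th> headers and empty cells become empty <td>. A dependency-free sketch of | |||
| # the same construction over plain lists: | |||
| def sheet_to_html(sheetname, rows): | |||
|     tb = f"<table><caption>{sheetname}</caption><tr>" | |||
|     tb += "".join(f"<th>{v}</th>" for v in rows[0]) + "</tr>" | |||
|     for r in rows[1:]: | |||
|         tb += "<tr>" + "".join( | |||
|             "<td></td>" if v is None else f"<td>{v}</td>" for v in r) + "</tr>" | |||
|     return tb + "</table>\n" | |||
| # sheet_to_html("Staff", [("name", "age"), ("Alice", 30)]) | |||
| # -> '<table><caption>Staff</caption><tr><th>name</th><th>age</th></tr>' | |||
| #    '<tr><td>Alice</td><td>30</td></tr></table>\n' | |||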
| @@ -38,13 +41,15 @@ class HuExcelParser: | |||
| ti = list(rows[0]) | |||
| for r in list(rows[1:]): | |||
| l = [] | |||
| for i,c in enumerate(r): | |||
| if not c.value:continue | |||
| for i, c in enumerate(r): | |||
| if not c.value: | |||
| continue | |||
| t = str(ti[i].value) if i < len(ti) else "" | |||
| t += (":" if t else "") + str(c.value) | |||
| l.append(t) | |||
| l = "; ".join(l) | |||
| if sheetname.lower().find("sheet") <0: l += " ——"+sheetname | |||
| if sheetname.lower().find("sheet") < 0: | |||
| l += " ——" + sheetname | |||
| res.append(l) | |||
| return res | |||
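| # The second mode flattens every data row into "header: value" pairs and tags | |||
| # on the sheet name unless it is a default "Sheet..." name. Simplified: | |||
| def row_to_text(headers, row, sheetname): | |||
|     parts = [f"{h}: {v}" if h else str(v) for h, v in zip(headers, row) if v] | |||
|     line = "; ".join(parts) | |||
|     if sheetname.lower().find("sheet") < 0:  # keep only meaningful sheet names | |||
|         line += " ——" + sheetname | |||
|     return line | |||
| # row_to_text(["name", "age"], ["Alice", 30], "Staff") | |||
| # -> 'name: Alice; age: 30 ——Staff' | |||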
| @@ -43,9 +43,11 @@ class HuParser: | |||
| "rag/res/deepdoc"), | |||
| local_files_only=True) | |||
| except Exception as e: | |||
| model_dir = snapshot_download(repo_id="InfiniFlow/text_concat_xgb_v1.0") | |||
| model_dir = snapshot_download( | |||
| repo_id="InfiniFlow/text_concat_xgb_v1.0") | |||
| self.updown_cnt_mdl.load_model(os.path.join(model_dir, "updown_concat_xgb.model")) | |||
| self.updown_cnt_mdl.load_model(os.path.join( | |||
| model_dir, "updown_concat_xgb.model")) | |||
| self.page_from = 0 | |||
| """ | |||
| If you have trouble downloading HuggingFace models, -_^ this might help!! | |||
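| # The docstring above (truncated here) is a tip for hosts that cannot reach | |||
| # HuggingFace. One common workaround (an assumption, not a quote from the full | |||
| # docstring) is pointing huggingface_hub at a mirror before any | |||
| # snapshot_download call: | |||
| import os | |||
| os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"  # example mirror URL | |||
| from huggingface_hub import snapshot_download | |||
| model_dir = snapshot_download(repo_id="InfiniFlow/text_concat_xgb_v1.0") | |||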
| @@ -72,7 +74,7 @@ class HuParser: | |||
| def _y_dis( | |||
| self, a, b): | |||
| return ( | |||
| b["top"] + b["bottom"] - a["top"] - a["bottom"]) / 2 | |||
| b["top"] + b["bottom"] - a["top"] - a["bottom"]) / 2 | |||
| def _match_proj(self, b): | |||
| proj_patt = [ | |||
| @@ -95,9 +97,9 @@ class HuParser: | |||
| tks_down = huqie.qie(down["text"][:LEN]).split(" ") | |||
| tks_up = huqie.qie(up["text"][-LEN:]).split(" ") | |||
| tks_all = up["text"][-LEN:].strip() \ | |||
| + (" " if re.match(r"[a-zA-Z0-9]+", | |||
| up["text"][-1] + down["text"][0]) else "") \ | |||
| + down["text"][:LEN].strip() | |||
| + (" " if re.match(r"[a-zA-Z0-9]+", | |||
| up["text"][-1] + down["text"][0]) else "") \ | |||
| + down["text"][:LEN].strip() | |||
| tks_all = huqie.qie(tks_all).split(" ") | |||
| fea = [ | |||
| up.get("R", -1) == down.get("R", -1), | |||
| @@ -119,7 +121,7 @@ class HuParser: | |||
| True if re.search(r"[,,][^。.]+$", up["text"]) else False, | |||
| True if re.search(r"[,,][^。.]+$", up["text"]) else False, | |||
| True if re.search(r"[\((][^\))]+$", up["text"]) | |||
| and re.search(r"[\))]", down["text"]) else False, | |||
| and re.search(r"[\))]", down["text"]) else False, | |||
| self._match_proj(down), | |||
| True if re.match(r"[A-Z]", down["text"]) else False, | |||
| True if re.match(r"[A-Z]", up["text"][-1]) else False, | |||
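| # These booleans feed the XGBoost model that decides whether two vertically | |||
| # adjacent boxes continue one paragraph. A few of them recomputed standalone | |||
| # (simplified; the huqie token features are omitted): | |||
| import re | |||
| def pair_features(up, down): | |||
|     return [ | |||
|         up.get("R", -1) == down.get("R", -1),          # same detected table row | |||
|         bool(re.search(r"[,,][^。.]+$", up["text"])),  # dangling clause after a comma | |||
|         bool(re.match(r"[A-Z]", down["text"])),        # down starts like a new sentence | |||
|         up["text"].strip().endswith(("。", "?", "!", "?")),  # up already ends one | |||
|     ] | |||
| up = {"R": 0, "text": "the quick brown fox, and"} | |||
| down = {"R": 0, "text": "jumps over the lazy dog."} | |||
| print(pair_features(up, down))  # [True, True, False, False] | |||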
| @@ -181,7 +183,7 @@ class HuParser: | |||
| continue | |||
| for tb in tbls: # for table | |||
| left, top, right, bott = tb["x0"] - MARGIN, tb["top"] - MARGIN, \ | |||
| tb["x1"] + MARGIN, tb["bottom"] + MARGIN | |||
| tb["x1"] + MARGIN, tb["bottom"] + MARGIN | |||
| left *= ZM | |||
| top *= ZM | |||
| right *= ZM | |||
| @@ -235,7 +237,8 @@ class HuParser: | |||
| b["R_top"] = rows[ii]["top"] | |||
| b["R_bott"] = rows[ii]["bottom"] | |||
| ii = Recognizer.find_overlapped_with_threashold(b, headers, thr=0.3) | |||
| ii = Recognizer.find_overlapped_with_threashold( | |||
| b, headers, thr=0.3) | |||
| if ii is not None: | |||
| b["H_top"] = headers[ii]["top"] | |||
| b["H_bott"] = headers[ii]["bottom"] | |||
| @@ -272,7 +275,8 @@ class HuParser: | |||
| ) | |||
| # merge chars in the same rect | |||
| for c in Recognizer.sort_X_firstly(chars, self.mean_width[pagenum - 1] // 4): | |||
| for c in Recognizer.sort_X_firstly( | |||
| chars, self.mean_width[pagenum - 1] // 4): | |||
| ii = Recognizer.find_overlapped(c, bxs) | |||
| if ii is None: | |||
| self.lefted_chars.append(c) | |||
| @@ -283,13 +287,15 @@ class HuParser: | |||
| self.lefted_chars.append(c) | |||
| continue | |||
| if c["text"] == " " and bxs[ii]["text"]: | |||
| if re.match(r"[0-9a-zA-Z,.?;:!%%]", bxs[ii]["text"][-1]): bxs[ii]["text"] += " " | |||
| if re.match(r"[0-9a-zA-Z,.?;:!%%]", bxs[ii]["text"][-1]): | |||
| bxs[ii]["text"] += " " | |||
| else: | |||
| bxs[ii]["text"] += c["text"] | |||
| for b in bxs: | |||
| if not b["text"]: | |||
| left, right, top, bott = b["x0"] * ZM, b["x1"] * ZM, b["top"] * ZM, b["bottom"] * ZM | |||
| left, right, top, bott = b["x0"] * ZM, b["x1"] * \ | |||
| ZM, b["top"] * ZM, b["bottom"] * ZM | |||
| b["text"] = self.ocr.recognize(np.array(img), | |||
| np.array([[left, top], [right, top], [right, bott], [left, bott]], | |||
| dtype=np.float32)) | |||
| @@ -302,7 +308,8 @@ class HuParser: | |||
| def _layouts_rec(self, ZM, drop=True): | |||
| assert len(self.page_images) == len(self.boxes) | |||
| self.boxes, self.page_layout = self.layouter(self.page_images, self.boxes, ZM, drop=drop) | |||
| self.boxes, self.page_layout = self.layouter( | |||
| self.page_images, self.boxes, ZM, drop=drop) | |||
| # cumulative Y | |||
| for i in range(len(self.boxes)): | |||
| self.boxes[i]["top"] += \ | |||
| @@ -332,7 +339,8 @@ class HuParser: | |||
| "equation"]: | |||
| i += 1 | |||
| continue | |||
| if abs(self._y_dis(b, b_)) < self.mean_height[bxs[i]["page_number"] - 1] / 3: | |||
| if abs(self._y_dis(b, b_) | |||
| ) < self.mean_height[bxs[i]["page_number"] - 1] / 3: | |||
| # merge | |||
| bxs[i]["x1"] = b_["x1"] | |||
| bxs[i]["top"] = (b["top"] + b_["top"]) / 2 | |||
| @@ -366,12 +374,15 @@ class HuParser: | |||
| self.boxes = bxs | |||
| def _naive_vertical_merge(self): | |||
| bxs = Recognizer.sort_Y_firstly(self.boxes, np.median(self.mean_height) / 3) | |||
| bxs = Recognizer.sort_Y_firstly( | |||
| self.boxes, np.median( | |||
| self.mean_height) / 3) | |||
| i = 0 | |||
| while i + 1 < len(bxs): | |||
| b = bxs[i] | |||
| b_ = bxs[i + 1] | |||
| if b["page_number"] < b_["page_number"] and re.match(r"[0-9 •一—-]+$", b["text"]): | |||
| if b["page_number"] < b_["page_number"] and re.match( | |||
| r"[0-9 •一—-]+$", b["text"]): | |||
| bxs.pop(i) | |||
| continue | |||
| if not b["text"].strip(): | |||
| @@ -379,7 +390,8 @@ class HuParser: | |||
| continue | |||
| concatting_feats = [ | |||
| b["text"].strip()[-1] in ",;:'\",、‘“;:-", | |||
| len(b["text"].strip()) > 1 and b["text"].strip()[-2] in ",;:'\",‘“、;:", | |||
| len(b["text"].strip()) > 1 and b["text"].strip( | |||
| )[-2] in ",;:'\",‘“、;:", | |||
| b["text"].strip()[0] in "。;?!?”)),,、:", | |||
| ] | |||
| # features for not concating | |||
| @@ -387,7 +399,7 @@ class HuParser: | |||
| b.get("layoutno", 0) != b.get("layoutno", 0), | |||
| b["text"].strip()[-1] in "。?!?", | |||
| self.is_english and b["text"].strip()[-1] in ".!?", | |||
| b["page_number"] == b_["page_number"] and b_["top"] - \ | |||
| b["page_number"] == b_["page_number"] and b_["top"] - | |||
| b["bottom"] > self.mean_height[b["page_number"] - 1] * 1.5, | |||
| b["page_number"] < b_["page_number"] and abs( | |||
| b["x0"] - b_["x0"]) > self.mean_width[b["page_number"] - 1] * 4, | |||
| @@ -396,7 +408,12 @@ class HuParser: | |||
| detach_feats = [b["x1"] < b_["x0"], | |||
| b["x0"] > b_["x1"]] | |||
| if (any(feats) and not any(concatting_feats)) or any(detach_feats): | |||
| print(b["text"], b_["text"], any(feats), any(concatting_feats), any(detach_feats)) | |||
| print( | |||
| b["text"], | |||
| b_["text"], | |||
| any(feats), | |||
| any(concatting_feats), | |||
| any(detach_feats)) | |||
| i += 1 | |||
| continue | |||
| # merge up and down | |||
| @@ -526,31 +543,39 @@ class HuParser: | |||
| i += 1 | |||
| continue | |||
| findit = True | |||
| eng = re.match(r"[0-9a-zA-Z :'.-]{5,}", self.boxes[i]["text"].strip()) | |||
| eng = re.match( | |||
| r"[0-9a-zA-Z :'.-]{5,}", | |||
| self.boxes[i]["text"].strip()) | |||
| self.boxes.pop(i) | |||
| if i >= len(self.boxes): break | |||
| if i >= len(self.boxes): | |||
| break | |||
| prefix = self.boxes[i]["text"].strip()[:3] if not eng else " ".join( | |||
| self.boxes[i]["text"].strip().split(" ")[:2]) | |||
| while not prefix: | |||
| self.boxes.pop(i) | |||
| if i >= len(self.boxes): break | |||
| if i >= len(self.boxes): | |||
| break | |||
| prefix = self.boxes[i]["text"].strip()[:3] if not eng else " ".join( | |||
| self.boxes[i]["text"].strip().split(" ")[:2]) | |||
| self.boxes.pop(i) | |||
| if i >= len(self.boxes) or not prefix: break | |||
| if i >= len(self.boxes) or not prefix: | |||
| break | |||
| for j in range(i, min(i + 128, len(self.boxes))): | |||
| if not re.match(prefix, self.boxes[j]["text"]): | |||
| continue | |||
| for k in range(i, j): self.boxes.pop(i) | |||
| for k in range(i, j): | |||
| self.boxes.pop(i) | |||
| break | |||
| if findit: return | |||
| if findit: | |||
| return | |||
| page_dirty = [0] * len(self.page_images) | |||
| for b in self.boxes: | |||
| if re.search(r"(··|··|··)", b["text"]): | |||
| page_dirty[b["page_number"] - 1] += 1 | |||
| page_dirty = set([i + 1 for i, t in enumerate(page_dirty) if t > 3]) | |||
| if not page_dirty: return | |||
| if not page_dirty: | |||
| return | |||
| i = 0 | |||
| while i < len(self.boxes): | |||
| if self.boxes[i]["page_number"] in page_dirty: | |||
| @@ -582,7 +607,8 @@ class HuParser: | |||
| b_["top"] = b["top"] | |||
| self.boxes.pop(i) | |||
| def _extract_table_figure(self, need_image, ZM, return_html, need_position): | |||
| def _extract_table_figure(self, need_image, ZM, | |||
| return_html, need_position): | |||
| tables = {} | |||
| figures = {} | |||
| # extract figure and table boxes | |||
| @@ -594,7 +620,7 @@ class HuParser: | |||
| i += 1 | |||
| continue | |||
| lout_no = str(self.boxes[i]["page_number"]) + \ | |||
| "-" + str(self.boxes[i]["layoutno"]) | |||
| "-" + str(self.boxes[i]["layoutno"]) | |||
| if TableStructureRecognizer.is_caption(self.boxes[i]) or self.boxes[i]["layout_type"] in ["table caption", | |||
| "title", | |||
| "figure caption", | |||
| @@ -761,7 +787,8 @@ class HuParser: | |||
| for k, bxs in tables.items(): | |||
| if not bxs: | |||
| continue | |||
| bxs = Recognizer.sort_Y_firstly(bxs, np.mean([(b["bottom"] - b["top"]) / 2 for b in bxs])) | |||
| bxs = Recognizer.sort_Y_firstly(bxs, np.mean( | |||
| [(b["bottom"] - b["top"]) / 2 for b in bxs])) | |||
| poss = [] | |||
| res.append((cropout(bxs, "table", poss), | |||
| self.tbl_det.construct_table(bxs, html=return_html, is_english=self.is_english))) | |||
| @@ -769,7 +796,8 @@ class HuParser: | |||
| assert len(positions) == len(res) | |||
| if need_position: return list(zip(res, positions)) | |||
| if need_position: | |||
| return list(zip(res, positions)) | |||
| return res | |||
| def proj_match(self, line): | |||
| @@ -873,7 +901,8 @@ class HuParser: | |||
| boxes.pop(0) | |||
| mw = np.mean(widths) | |||
| if mj or mw / pw >= 0.35 or mw > 200: | |||
| res.append("\n".join([c["text"] + self._line_tag(c, ZM) for c in lines])) | |||
| res.append( | |||
| "\n".join([c["text"] + self._line_tag(c, ZM) for c in lines])) | |||
| else: | |||
| logging.debug("REMOVED: " + | |||
| "<<".join([c["text"] for c in lines])) | |||
| @@ -883,13 +912,16 @@ class HuParser: | |||
| @staticmethod | |||
| def total_page_number(fnm, binary=None): | |||
| try: | |||
| pdf = pdfplumber.open(fnm) if not binary else pdfplumber.open(BytesIO(binary)) | |||
| pdf = pdfplumber.open( | |||
| fnm) if not binary else pdfplumber.open(BytesIO(binary)) | |||
| return len(pdf.pages) | |||
| except Exception as e: | |||
| pdf = fitz.open(fnm) if not binary else fitz.open(stream=fnm, filetype="pdf") | |||
| pdf = fitz.open(fnm) if not binary else fitz.open( | |||
| stream=fnm, filetype="pdf") | |||
| return len(pdf) | |||
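| # total_page_number tries pdfplumber first and falls back to PyMuPDF when the | |||
| # file will not open. The same two-library fallback as a standalone sketch: | |||
| from io import BytesIO | |||
| import pdfplumber | |||
| import fitz  # PyMuPDF | |||
| def total_pages(fnm, binary=None): | |||
|     try: | |||
|         with pdfplumber.open(BytesIO(binary) if binary else fnm) as pdf: | |||
|             return len(pdf.pages) | |||
|     except Exception: | |||
|         pdf = fitz.open(stream=binary, filetype="pdf") if binary \ | |||
|             else fitz.open(fnm) | |||
|         return len(pdf) | |||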
| def __images__(self, fnm, zoomin=3, page_from=0, page_to=299, callback=None): | |||
| def __images__(self, fnm, zoomin=3, page_from=0, | |||
| page_to=299, callback=None): | |||
| self.lefted_chars = [] | |||
| self.mean_height = [] | |||
| self.mean_width = [] | |||
| @@ -899,21 +931,26 @@ class HuParser: | |||
| self.page_layout = [] | |||
| self.page_from = page_from | |||
| try: | |||
| self.pdf = pdfplumber.open(fnm) if isinstance(fnm, str) else pdfplumber.open(BytesIO(fnm)) | |||
| self.pdf = pdfplumber.open(fnm) if isinstance( | |||
| fnm, str) else pdfplumber.open(BytesIO(fnm)) | |||
| self.page_images = [p.to_image(resolution=72 * zoomin).annotated for i, p in | |||
| enumerate(self.pdf.pages[page_from:page_to])] | |||
| self.page_chars = [[c for c in page.chars if self._has_color(c)] for page in | |||
| self.pdf.pages[page_from:page_to]] | |||
| self.total_page = len(self.pdf.pages) | |||
| except Exception as e: | |||
| self.pdf = fitz.open(fnm) if isinstance(fnm, str) else fitz.open(stream=fnm, filetype="pdf") | |||
| self.pdf = fitz.open(fnm) if isinstance( | |||
| fnm, str) else fitz.open( | |||
| stream=fnm, filetype="pdf") | |||
| self.page_images = [] | |||
| self.page_chars = [] | |||
| mat = fitz.Matrix(zoomin, zoomin) | |||
| self.total_page = len(self.pdf) | |||
| for i, page in enumerate(self.pdf): | |||
| if i < page_from: continue | |||
| if i >= page_to: break | |||
| if i < page_from: | |||
| continue | |||
| if i >= page_to: | |||
| break | |||
| pix = page.get_pixmap(matrix=mat) | |||
| img = Image.frombytes("RGB", [pix.width, pix.height], | |||
| pix.samples) | |||
| @@ -930,7 +967,7 @@ class HuParser: | |||
| if isinstance(a, dict): | |||
| self.outlines.append((a["/Title"], depth)) | |||
| continue | |||
| dfs(a, depth+1) | |||
| dfs(a, depth + 1) | |||
| dfs(outlines, 0) | |||
| except Exception as e: | |||
| logging.warning(f"Outlines exception: {e}") | |||
| @@ -940,8 +977,9 @@ class HuParser: | |||
| logging.info("Images converted.") | |||
| self.is_english = [re.search(r"[a-zA-Z0-9,/¸;:'\[\]\(\)!@#$%^&*\"?<>._-]{30,}", "".join( | |||
| random.choices([c["text"] for c in self.page_chars[i]], k=min(100, len(self.page_chars[i]))))) for i in | |||
| range(len(self.page_chars))] | |||
| if sum([1 if e else 0 for e in self.is_english]) > len(self.page_images) / 2: | |||
| range(len(self.page_chars))] | |||
| if sum([1 if e else 0 for e in self.is_english]) > len( | |||
| self.page_images) / 2: | |||
| self.is_english = True | |||
| else: | |||
| self.is_english = False | |||
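| # Language detection is a sampling heuristic: draw up to 100 random characters | |||
| # per page, look for a run of 30+ Latin/ASCII symbols, and call the document | |||
| # English when more than half the pages pass. Distilled: | |||
| import random | |||
| import re | |||
| LATIN_RUN = re.compile(r"[a-zA-Z0-9,/¸;:'\[\]\(\)!@#$%^&*\"?<>._-]{30,}") | |||
| def mostly_english(page_chars): | |||
|     votes = 0 | |||
|     for chars in page_chars: | |||
|         sample = "".join(random.choices(chars, k=min(100, len(chars)))) if chars else "" | |||
|         votes += 1 if LATIN_RUN.search(sample) else 0 | |||
|     return votes > len(page_chars) / 2 | |||
| latin_page = list("TheQuickBrownFoxJumpsOverTheLazyDog0123456789" * 3) | |||
| cjk_page = list("深度文档解析与检索增强生成" * 10) | |||
| print(mostly_english([latin_page, cjk_page]))  # 1 of 2 pages votes -> False | |||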
| @@ -970,9 +1008,11 @@ class HuParser: | |||
| # self.page_cum_height.append( | |||
| # np.max([c["bottom"] for c in chars])) | |||
| self.__ocr(i + 1, img, chars, zoomin) | |||
| if callback: callback(prog=(i + 1) * 0.6 / len(self.page_images), msg="") | |||
| if callback: | |||
| callback(prog=(i + 1) * 0.6 / len(self.page_images), msg="") | |||
| if not self.is_english and not any([c for c in self.page_chars]) and self.boxes: | |||
| if not self.is_english and not any( | |||
| [c for c in self.page_chars]) and self.boxes: | |||
| bxes = [b for bxs in self.boxes for b in bxs] | |||
| self.is_english = re.search(r"[\na-zA-Z0-9,/¸;:'\[\]\(\)!@#$%^&*\"?<>._-]{30,}", | |||
| "".join([b["text"] for b in random.choices(bxes, k=min(30, len(bxes)))])) | |||
| @@ -989,7 +1029,8 @@ class HuParser: | |||
| self._text_merge() | |||
| self._concat_downward() | |||
| self._filter_forpages() | |||
| tbls = self._extract_table_figure(need_image, zoomin, return_html, False) | |||
| tbls = self._extract_table_figure( | |||
| need_image, zoomin, return_html, False) | |||
| return self.__filterout_scraps(deepcopy(self.boxes), zoomin), tbls | |||
| def remove_tag(self, txt): | |||
| @@ -1003,15 +1044,19 @@ class HuParser: | |||
| "#").strip("@").split("\t") | |||
| left, right, top, bottom = float(left), float( | |||
| right), float(top), float(bottom) | |||
| poss.append(([int(p) - 1 for p in pn.split("-")], left, right, top, bottom)) | |||
| poss.append(([int(p) - 1 for p in pn.split("-")], | |||
| left, right, top, bottom)) | |||
| if not poss: | |||
| if need_position: return None, None | |||
| if need_position: | |||
| return None, None | |||
| return | |||
| max_width = max(np.max([right - left for (_, left, right, _, _) in poss]), 6) | |||
| max_width = max( | |||
| np.max([right - left for (_, left, right, _, _) in poss]), 6) | |||
| GAP = 6 | |||
| pos = poss[0] | |||
| poss.insert(0, ([pos[0][0]], pos[1], pos[2], max(0, pos[3] - 120), max(pos[3] - GAP, 0))) | |||
| poss.insert(0, ([pos[0][0]], pos[1], pos[2], max( | |||
| 0, pos[3] - 120), max(pos[3] - GAP, 0))) | |||
| pos = poss[-1] | |||
| poss.append(([pos[0][-1]], pos[1], pos[2], min(self.page_images[pos[0][-1]].size[1] / ZM, pos[4] + GAP), | |||
| min(self.page_images[pos[0][-1]].size[1] / ZM, pos[4] + 120))) | |||
| @@ -1026,7 +1071,7 @@ class HuParser: | |||
| self.page_images[pns[0]].crop((left * ZM, top * ZM, | |||
| right * | |||
| ZM, min( | |||
| bottom, self.page_images[pns[0]].size[1]) | |||
| bottom, self.page_images[pns[0]].size[1]) | |||
| )) | |||
| ) | |||
| if 0 < ii < len(poss) - 1: | |||
| @@ -1047,7 +1092,8 @@ class HuParser: | |||
| bottom -= self.page_images[pn].size[1] | |||
| if not imgs: | |||
| if need_position: return None, None | |||
| if need_position: | |||
| return None, None | |||
| return | |||
| height = 0 | |||
| for img in imgs: | |||
| @@ -1076,12 +1122,14 @@ class HuParser: | |||
| pn = bx["page_number"] | |||
| top = bx["top"] - self.page_cum_height[pn - 1] | |||
| bott = bx["bottom"] - self.page_cum_height[pn - 1] | |||
| poss.append((pn, bx["x0"], bx["x1"], top, min(bott, self.page_images[pn - 1].size[1] / ZM))) | |||
| poss.append((pn, bx["x0"], bx["x1"], top, min( | |||
| bott, self.page_images[pn - 1].size[1] / ZM))) | |||
| while bott * ZM > self.page_images[pn - 1].size[1]: | |||
| bott -= self.page_images[pn - 1].size[1] / ZM | |||
| top = 0 | |||
| pn += 1 | |||
| poss.append((pn, bx["x0"], bx["x1"], top, min(bott, self.page_images[pn - 1].size[1] / ZM))) | |||
| poss.append((pn, bx["x0"], bx["x1"], top, min( | |||
| bott, self.page_images[pn - 1].size[1] / ZM))) | |||
| return poss | |||
| @@ -1090,11 +1138,14 @@ class PlainParser(object): | |||
| self.outlines = [] | |||
| lines = [] | |||
| try: | |||
| self.pdf = pdf2_read(filename if isinstance(filename, str) else BytesIO(filename)) | |||
| self.pdf = pdf2_read( | |||
| filename if isinstance( | |||
| filename, str) else BytesIO(filename)) | |||
| for page in self.pdf.pages[from_page:to_page]: | |||
| lines.extend([t for t in page.extract_text().split("\n")]) | |||
| outlines = self.pdf.outline | |||
| def dfs(arr, depth): | |||
| for a in arr: | |||
| if isinstance(a, dict): | |||
| @@ -1117,5 +1168,6 @@ class PlainParser(object): | |||
| def remove_tag(txt): | |||
| raise NotImplementedError | |||
| if __name__ == "__main__": | |||
| pass | |||
| @@ -23,7 +23,8 @@ class HuPptParser(object): | |||
| tb = shape.table | |||
| rows = [] | |||
| for i in range(1, len(tb.rows)): | |||
| rows.append("; ".join([tb.cell(0, j).text + ": " + tb.cell(i, j).text for j in range(len(tb.columns)) if tb.cell(i, j)])) | |||
| rows.append("; ".join([tb.cell( | |||
| 0, j).text + ": " + tb.cell(i, j).text for j in range(len(tb.columns)) if tb.cell(i, j)])) | |||
| return "\n".join(rows) | |||
| if shape.has_text_frame: | |||
| @@ -31,9 +32,10 @@ class HuPptParser(object): | |||
| if shape.shape_type == 6: | |||
| texts = [] | |||
| for p in sorted(shape.shapes, key=lambda x: (x.top//10, x.left)): | |||
| for p in sorted(shape.shapes, key=lambda x: (x.top // 10, x.left)): | |||
| t = self.__extract(p) | |||
| if t: texts.append(t) | |||
| if t: | |||
| texts.append(t) | |||
| return "\n".join(texts) | |||
| def __call__(self, fnm, from_page, to_page, callback=None): | |||
| @@ -43,12 +45,16 @@ class HuPptParser(object): | |||
| txts = [] | |||
| self.total_page = len(ppt.slides) | |||
| for i, slide in enumerate(ppt.slides): | |||
| if i < from_page: continue | |||
| if i >= to_page:break | |||
| if i < from_page: | |||
| continue | |||
| if i >= to_page: | |||
| break | |||
| texts = [] | |||
| for shape in sorted(slide.shapes, key=lambda x: (x.top//10, x.left)): | |||
| for shape in sorted( | |||
| slide.shapes, key=lambda x: (x.top // 10, x.left)): | |||
| txt = self.__extract(shape) | |||
| if txt: texts.append(txt) | |||
| if txt: | |||
| texts.append(txt) | |||
| txts.append("\n".join(texts)) | |||
| return txts | |||
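| # Both loops above sort shapes by (top // 10, left): the integer division | |||
| # buckets shapes sitting on roughly the same line, so left-to-right order wins | |||
| # within a line. A toy illustration (note that tops 9 and 11 would still land | |||
| # in different buckets; the tolerance is coarse by design): | |||
| shapes = [ | |||
|     {"name": "title", "top": 12, "left": 40}, | |||
|     {"name": "logo", "top": 18, "left": 500},   # same 10-unit band as the title | |||
|     {"name": "body", "top": 120, "left": 40}, | |||
| ] | |||
| ordered = sorted(shapes, key=lambda s: (s["top"] // 10, s["left"])) | |||
| print([s["name"] for s in ordered])  # ['title', 'logo', 'body'] | |||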
| @@ -24,18 +24,19 @@ from deepdoc.vision import Recognizer | |||
| class LayoutRecognizer(Recognizer): | |||
| labels = [ | |||
| "_background_", | |||
| "Text", | |||
| "Title", | |||
| "Figure", | |||
| "Figure caption", | |||
| "Table", | |||
| "Table caption", | |||
| "Header", | |||
| "Footer", | |||
| "Reference", | |||
| "Equation", | |||
| ] | |||
| "_background_", | |||
| "Text", | |||
| "Title", | |||
| "Figure", | |||
| "Figure caption", | |||
| "Table", | |||
| "Table caption", | |||
| "Header", | |||
| "Footer", | |||
| "Reference", | |||
| "Equation", | |||
| ] | |||
| def __init__(self, domain): | |||
| try: | |||
| model_dir = snapshot_download( | |||
| @@ -47,10 +48,12 @@ class LayoutRecognizer(Recognizer): | |||
| except Exception as e: | |||
| model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc") | |||
| super().__init__(self.labels, domain, model_dir)#os.path.join(get_project_base_directory(), "rag/res/deepdoc/")) | |||
| # os.path.join(get_project_base_directory(), "rag/res/deepdoc/")) | |||
| super().__init__(self.labels, domain, model_dir) | |||
| self.garbage_layouts = ["footer", "header", "reference"] | |||
| def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.2, batch_size=16, drop=True): | |||
| def __call__(self, image_list, ocr_res, scale_factor=3, | |||
| thr=0.2, batch_size=16, drop=True): | |||
| def __is_garbage(b): | |||
| patt = [r"^•+$", r"(版权归©|免责条款|地址[::])", r"\.{3,}", "^[0-9]{1,2} / ?[0-9]{1,2}$", | |||
| r"^[0-9]{1,2} of [0-9]{1,2}$", "^http://[^ ]{12,}", | |||
| @@ -75,7 +78,8 @@ class LayoutRecognizer(Recognizer): | |||
| "top": b["bbox"][1] / scale_factor, "bottom": b["bbox"][-1] / scale_factor, | |||
| "page_number": pn, | |||
| } for b in lts] | |||
| lts = self.sort_Y_firstly(lts, np.mean([l["bottom"]-l["top"] for l in lts]) / 2) | |||
| lts = self.sort_Y_firstly(lts, np.mean( | |||
| [l["bottom"] - l["top"] for l in lts]) / 2) | |||
| lts = self.layouts_cleanup(bxs, lts) | |||
| page_layout.append(lts) | |||
| @@ -93,17 +97,20 @@ class LayoutRecognizer(Recognizer): | |||
| continue | |||
| ii = self.find_overlapped_with_threashold(bxs[i], lts_, | |||
| thr=0.4) | |||
| thr=0.4) | |||
| if ii is None: # belong to nothing | |||
| bxs[i]["layout_type"] = "" | |||
| i += 1 | |||
| continue | |||
| lts_[ii]["visited"] = True | |||
| keep_feats = [ | |||
| lts_[ii]["type"] == "footer" and bxs[i]["bottom"] < image_list[pn].size[1]*0.9/scale_factor, | |||
| lts_[ii]["type"] == "header" and bxs[i]["top"] > image_list[pn].size[1]*0.1/scale_factor, | |||
| lts_[ | |||
| ii]["type"] == "footer" and bxs[i]["bottom"] < image_list[pn].size[1] * 0.9 / scale_factor, | |||
| lts_[ | |||
| ii]["type"] == "header" and bxs[i]["top"] > image_list[pn].size[1] * 0.1 / scale_factor, | |||
| ] | |||
| if drop and lts_[ii]["type"] in self.garbage_layouts and not any(keep_feats): | |||
| if drop and lts_[ | |||
| ii]["type"] in self.garbage_layouts and not any(keep_feats): | |||
| if lts_[ii]["type"] not in garbages: | |||
| garbages[lts_[ii]["type"]] = [] | |||
| garbages[lts_[ii]["type"]].append(bxs[i]["text"]) | |||
| @@ -111,7 +118,8 @@ class LayoutRecognizer(Recognizer): | |||
| continue | |||
| bxs[i]["layoutno"] = f"{ty}-{ii}" | |||
| bxs[i]["layout_type"] = lts_[ii]["type"] if lts_[ii]["type"]!="equation" else "figure" | |||
| bxs[i]["layout_type"] = lts_[ii]["type"] if lts_[ | |||
| ii]["type"] != "equation" else "figure" | |||
| i += 1 | |||
| for lt in ["footer", "header", "reference", "figure caption", | |||
| @@ -120,7 +128,7 @@ class LayoutRecognizer(Recognizer): | |||
| # add box to figure layouts which have no text box | |||
| for i, lt in enumerate( | |||
| [lt for lt in lts if lt["type"] in ["figure","equation"]]): | |||
| [lt for lt in lts if lt["type"] in ["figure", "equation"]]): | |||
| if lt.get("visited"): | |||
| continue | |||
| lt = deepcopy(lt) | |||
| @@ -143,6 +151,3 @@ class LayoutRecognizer(Recognizer): | |||
| ocr_res = [b for b in ocr_res if b["text"].strip() not in garbag_set] | |||
| return ocr_res, page_layout | |||
| @@ -63,6 +63,7 @@ class DecodeImage(object): | |||
| data['image'] = img | |||
| return data | |||
| class StandardizeImage(object): | |||
| """normalize image | |||
| Args: | |||
| @@ -707,4 +708,4 @@ def preprocess(im, preprocess_ops): | |||
| im, im_info = decode_image(im, im_info) | |||
| for operator in preprocess_ops: | |||
| im, im_info = operator(im, im_info) | |||
| return im, im_info | |||
| return im, im_info | |||
| @@ -11,12 +11,20 @@ | |||
| # limitations under the License. | |||
| # | |||
| import os, sys | |||
| sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), '../../'))) | |||
| import numpy as np | |||
| import argparse | |||
| from deepdoc.vision import OCR, init_in_out | |||
| from deepdoc.vision.seeit import draw_box | |||
| from deepdoc.vision import OCR, init_in_out | |||
| import argparse | |||
| import numpy as np | |||
| import os | |||
| import sys | |||
| sys.path.insert( | |||
| 0, | |||
| os.path.abspath( | |||
| os.path.join( | |||
| os.path.dirname( | |||
| os.path.abspath(__file__)), | |||
| '../../'))) | |||
| def main(args): | |||
| ocr = OCR() | |||
| @@ -26,14 +34,14 @@ def main(args): | |||
| bxs = ocr(np.array(img)) | |||
| bxs = [(line[0], line[1][0]) for line in bxs] | |||
| bxs = [{ | |||
| "text": t, | |||
| "bbox": [b[0][0], b[0][1], b[1][0], b[-1][1]], | |||
| "type": "ocr", | |||
| "score": 1} for b, t in bxs if b[0][0] <= b[1][0] and b[0][1] <= b[-1][1]] | |||
| "text": t, | |||
| "bbox": [b[0][0], b[0][1], b[1][0], b[-1][1]], | |||
| "type": "ocr", | |||
| "score": 1} for b, t in bxs if b[0][0] <= b[1][0] and b[0][1] <= b[-1][1]] | |||
| img = draw_box(images[i], bxs, ["ocr"], 1.) | |||
| img.save(outputs[i], quality=95) | |||
| with open(outputs[i] + ".txt", "w+") as f: f.write("\n".join([o["text"] for o in bxs])) | |||
| with open(outputs[i] + ".txt", "w+") as f: | |||
| f.write("\n".join([o["text"] for o in bxs])) | |||
| if __name__ == "__main__": | |||
| @@ -42,6 +50,6 @@ if __name__ == "__main__": | |||
| help="Directory where to store images or PDFs, or a file path to a single image or PDF", | |||
| required=True) | |||
| parser.add_argument('--output_dir', help="Directory where to store the output images. Default: './ocr_outputs'", | |||
| default="./ocr_outputs") | |||
| default="./ocr_outputs") | |||
| args = parser.parse_args() | |||
| main(args) | |||
| main(args) | |||
| @@ -11,24 +11,35 @@ | |||
| # limitations under the License. | |||
| # | |||
| import os, sys | |||
| from deepdoc.vision.seeit import draw_box | |||
| from deepdoc.vision import Recognizer, LayoutRecognizer, TableStructureRecognizer, OCR, init_in_out | |||
| from api.utils.file_utils import get_project_base_directory | |||
| import argparse | |||
| import os | |||
| import sys | |||
| import re | |||
| import numpy as np | |||
| sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), '../../'))) | |||
| import argparse | |||
| from api.utils.file_utils import get_project_base_directory | |||
| from deepdoc.vision import Recognizer, LayoutRecognizer, TableStructureRecognizer, OCR, init_in_out | |||
| from deepdoc.vision.seeit import draw_box | |||
| sys.path.insert( | |||
| 0, | |||
| os.path.abspath( | |||
| os.path.join( | |||
| os.path.dirname( | |||
| os.path.abspath(__file__)), | |||
| '../../'))) | |||
| def main(args): | |||
| images, outputs = init_in_out(args) | |||
| if args.mode.lower() == "layout": | |||
| labels = LayoutRecognizer.labels | |||
| detr = Recognizer(labels, "layout", os.path.join(get_project_base_directory(), "rag/res/deepdoc/")) | |||
| detr = Recognizer( | |||
| labels, | |||
| "layout", | |||
| os.path.join( | |||
| get_project_base_directory(), | |||
| "rag/res/deepdoc/")) | |||
| if args.mode.lower() == "tsr": | |||
| labels = TableStructureRecognizer.labels | |||
| detr = TableStructureRecognizer() | |||
| @@ -39,7 +50,8 @@ def main(args): | |||
| if args.mode.lower() == "tsr": | |||
| #lyt = [t for t in lyt if t["type"] == "table column"] | |||
| html = get_table_html(images[i], lyt, ocr) | |||
| with open(outputs[i]+".html", "w+") as f: f.write(html) | |||
| with open(outputs[i] + ".html", "w+") as f: | |||
| f.write(html) | |||
| lyt = [{ | |||
| "type": t["label"], | |||
| "bbox": [t["x0"], t["top"], t["x1"], t["bottom"]], | |||
| @@ -58,7 +70,7 @@ def get_table_html(img, tb_cpns, ocr): | |||
| "bottom": b[-1][1], | |||
| "layout_type": "table", | |||
| "page_number": 0} for b, t in boxes if b[0][0] <= b[1][0] and b[0][1] <= b[-1][1]], | |||
| np.mean([b[-1][1]-b[0][1] for b,_ in boxes]) / 3 | |||
| np.mean([b[-1][1] - b[0][1] for b, _ in boxes]) / 3 | |||
| ) | |||
| def gather(kwd, fzy=10, ption=0.6): | |||
| @@ -117,7 +129,7 @@ def get_table_html(img, tb_cpns, ocr): | |||
| margin-bottom: 50px; | |||
| border: 1px solid #e1e1e1; | |||
| } | |||
| caption { | |||
| color: #6ac1ca; | |||
| font-size: 20px; | |||
| @@ -126,25 +138,25 @@ def get_table_html(img, tb_cpns, ocr): | |||
| font-weight: 600; | |||
| margin-bottom: 10px; | |||
| } | |||
| ._table_1nkzy_11 table { | |||
| width: 100%%; | |||
| border-collapse: collapse; | |||
| } | |||
| th { | |||
| color: #fff; | |||
| background-color: #6ac1ca; | |||
| } | |||
| td:hover { | |||
| background: #c1e8e8; | |||
| } | |||
| tr:nth-child(even) { | |||
| background-color: #f2f2f2; | |||
| } | |||
| ._table_1nkzy_11 th, | |||
| ._table_1nkzy_11 td { | |||
| text-align: center; | |||
| @@ -157,7 +169,7 @@ def get_table_html(img, tb_cpns, ocr): | |||
| %s | |||
| </body> | |||
| </html> | |||
| """% TableStructureRecognizer.construct_table(boxes, html=True) | |||
| """ % TableStructureRecognizer.construct_table(boxes, html=True) | |||
| return html | |||
| @@ -168,7 +180,10 @@ if __name__ == "__main__": | |||
| required=True) | |||
| parser.add_argument('--output_dir', help="Directory where to store the output images. Default: './layouts_outputs'", | |||
| default="./layouts_outputs") | |||
| parser.add_argument('--threshold', help="A threshold to filter out detections. Default: 0.5", default=0.5) | |||
| parser.add_argument( | |||
| '--threshold', | |||
| help="A threshold to filter out detections. Default: 0.5", | |||
| default=0.5) | |||
| parser.add_argument('--mode', help="Task mode: layout recognition or table structure recognition", choices=["layout", "tsr"], | |||
| default="layout") | |||
| args = parser.parse_args() | |||
| @@ -44,7 +44,8 @@ class TableStructureRecognizer(Recognizer): | |||
| except Exception as e: | |||
| model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc") | |||
| super().__init__(self.labels, "tsr", model_dir)#os.path.join(get_project_base_directory(), "rag/res/deepdoc/")) | |||
| # os.path.join(get_project_base_directory(), "rag/res/deepdoc/")) | |||
| super().__init__(self.labels, "tsr", model_dir) | |||
| def __call__(self, images, thr=0.2): | |||
| tbls = super().__call__(images, thr) | |||
| @@ -138,7 +139,8 @@ class TableStructureRecognizer(Recognizer): | |||
| i = 0 | |||
| while i < len(boxes): | |||
| if TableStructureRecognizer.is_caption(boxes[i]): | |||
| if is_english: cap + " " | |||
| if is_english: | |||
| cap + " " | |||
| cap += boxes[i]["text"] | |||
| boxes.pop(i) | |||
| i -= 1 | |||
| @@ -164,7 +166,7 @@ class TableStructureRecognizer(Recognizer): | |||
| lst_r = rows[-1] | |||
| if lst_r[-1].get("R", "") != b.get("R", "") \ | |||
| or (b["top"] >= btm - 3 and lst_r[-1].get("R", "-1") != b.get("R", "-2") | |||
| ): # new row | |||
| ): # new row | |||
| btm = b["bottom"] | |||
| b["rn"] += 1 | |||
| rows.append([b]) | |||
| @@ -214,9 +216,9 @@ class TableStructureRecognizer(Recognizer): | |||
| j += 1 | |||
| continue | |||
| f = (j > 0 and tbl[ii][j - 1] and tbl[ii] | |||
| [j - 1][0].get("text")) or j == 0 | |||
| [j - 1][0].get("text")) or j == 0 | |||
| ff = (j + 1 < len(tbl[ii]) and tbl[ii][j + 1] and tbl[ii] | |||
| [j + 1][0].get("text")) or j + 1 >= len(tbl[ii]) | |||
| [j + 1][0].get("text")) or j + 1 >= len(tbl[ii]) | |||
| if f and ff: | |||
| j += 1 | |||
| continue | |||
| @@ -277,9 +279,9 @@ class TableStructureRecognizer(Recognizer): | |||
| i += 1 | |||
| continue | |||
| f = (i > 0 and tbl[i - 1][jj] and tbl[i - 1] | |||
| [jj][0].get("text")) or i == 0 | |||
| [jj][0].get("text")) or i == 0 | |||
| ff = (i + 1 < len(tbl) and tbl[i + 1][jj] and tbl[i + 1] | |||
| [jj][0].get("text")) or i + 1 >= len(tbl) | |||
| [jj][0].get("text")) or i + 1 >= len(tbl) | |||
| if f and ff: | |||
| i += 1 | |||
| continue | |||
| @@ -366,7 +368,8 @@ class TableStructureRecognizer(Recognizer): | |||
| continue | |||
| txt = "" | |||
| if arr: | |||
| h = min(np.min([c["bottom"] - c["top"] for c in arr]) / 2, 10) | |||
| h = min(np.min([c["bottom"] - c["top"] | |||
| for c in arr]) / 2, 10) | |||
| txt = " ".join([c["text"] | |||
| for c in Recognizer.sort_Y_firstly(arr, h)]) | |||
| txts.append(txt) | |||
| @@ -438,8 +441,8 @@ class TableStructureRecognizer(Recognizer): | |||
| else "") + headers[j - 1][k] | |||
| else: | |||
| headers[j][k] = headers[j - 1][k] \ | |||
| + (de if headers[j - 1][k] else "") \ | |||
| + headers[j][k] | |||
| + (de if headers[j - 1][k] else "") \ | |||
| + headers[j][k] | |||
| logging.debug( | |||
| f">>>>>>>>>>>>>>>>>{cap}:SIZE:{rowno}X{clmno} Header: {hdr_rowno}") | |||
| @@ -48,10 +48,12 @@ class Pdf(PdfParser): | |||
| callback(0.8, "Text extraction finished") | |||
| return [(b["text"] + self._line_tag(b, zoomin), b.get("layoutno","")) for b in self.boxes], tbls | |||
| return [(b["text"] + self._line_tag(b, zoomin), b.get("layoutno", "")) | |||
| for b in self.boxes], tbls | |||
| def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs): | |||
| def chunk(filename, binary=None, from_page=0, to_page=100000, | |||
| lang="Chinese", callback=None, **kwargs): | |||
| """ | |||
| Supported file formats are docx, pdf, txt. | |||
| Since a book is long and not all the parts are useful, if it's a PDF, | |||
| @@ -63,48 +65,63 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca | |||
| } | |||
| doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"]) | |||
| pdf_parser = None | |||
| sections,tbls = [], [] | |||
| sections, tbls = [], [] | |||
| if re.search(r"\.docx?$", filename, re.IGNORECASE): | |||
| callback(0.1, "Start to parse.") | |||
| doc_parser = DocxParser() | |||
| # TODO: table of contents need to be removed | |||
| sections, tbls = doc_parser(binary if binary else filename, from_page=from_page, to_page=to_page) | |||
| remove_contents_table(sections, eng=is_english(random_choices([t for t,_ in sections], k=200))) | |||
| sections, tbls = doc_parser( | |||
| binary if binary else filename, from_page=from_page, to_page=to_page) | |||
| remove_contents_table(sections, eng=is_english( | |||
| random_choices([t for t, _ in sections], k=200))) | |||
| callback(0.8, "Finish parsing.") | |||
| elif re.search(r"\.pdf$", filename, re.IGNORECASE): | |||
| pdf_parser = Pdf() if kwargs.get("parser_config",{}).get("layout_recognize", True) else PlainParser() | |||
| pdf_parser = Pdf() if kwargs.get( | |||
| "parser_config", {}).get( | |||
| "layout_recognize", True) else PlainParser() | |||
| sections, tbls = pdf_parser(filename if not binary else binary, | |||
| from_page=from_page, to_page=to_page, callback=callback) | |||
| from_page=from_page, to_page=to_page, callback=callback) | |||
| elif re.search(r"\.txt$", filename, re.IGNORECASE): | |||
| callback(0.1, "Start to parse.") | |||
| txt = "" | |||
| if binary:txt = binary.decode("utf-8") | |||
| if binary: | |||
| txt = binary.decode("utf-8") | |||
| else: | |||
| with open(filename, "r") as f: | |||
| while True: | |||
| l = f.readline() | |||
| if not l:break | |||
| if not l: | |||
| break | |||
| txt += l | |||
| sections = txt.split("\n") | |||
| sections = [(l,"") for l in sections if l] | |||
| remove_contents_table(sections, eng = is_english(random_choices([t for t,_ in sections], k=200))) | |||
| sections = [(l, "") for l in sections if l] | |||
| remove_contents_table(sections, eng=is_english( | |||
| random_choices([t for t, _ in sections], k=200))) | |||
| callback(0.8, "Finish parsing.") | |||
| else: raise NotImplementedError("file type not supported yet(docx, pdf, txt supported)") | |||
| else: | |||
| raise NotImplementedError( | |||
| "file type not supported yet(docx, pdf, txt supported)") | |||
| make_colon_as_title(sections) | |||
| bull = bullets_category([t for t in random_choices([t for t,_ in sections], k=100)]) | |||
| bull = bullets_category( | |||
| [t for t in random_choices([t for t, _ in sections], k=100)]) | |||
| if bull >= 0: | |||
| chunks = ["\n".join(ck) for ck in hierarchical_merge(bull, sections, 3)] | |||
| chunks = ["\n".join(ck) | |||
| for ck in hierarchical_merge(bull, sections, 3)] | |||
| else: | |||
| sections = [s.split("@") for s,_ in sections] | |||
| sections = [(pr[0], "@"+pr[1]) for pr in sections if len(pr)==2] | |||
| chunks = naive_merge(sections, kwargs.get("chunk_token_num", 256), kwargs.get("delimer", "\n。;!?")) | |||
| sections = [s.split("@") for s, _ in sections] | |||
| sections = [(pr[0], "@" + pr[1]) for pr in sections if len(pr) == 2] | |||
| chunks = naive_merge( | |||
| sections, kwargs.get( | |||
| "chunk_token_num", 256), kwargs.get( | |||
| "delimer", "\n。;!?")) | |||
| # is it English | |||
| eng = lang.lower() == "english"#is_english(random_choices([t for t, _ in sections], k=218)) | |||
| # is_english(random_choices([t for t, _ in sections], k=218)) | |||
| eng = lang.lower() == "english" | |||
| res = tokenize_table(tbls, doc, eng) | |||
| res.extend(tokenize_chunks(chunks, doc, eng, pdf_parser)) | |||
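| # naive_merge packs consecutive sections into chunks of at most chunk_token_num | |||
| # tokens. A simplified version of the idea, with whitespace tokens standing in | |||
| # for num_tokens_from_string and plain strings for the (text, tag) pairs: | |||
| def naive_merge_sketch(sections, chunk_token_num=256): | |||
|     chunks, count = [""], 0 | |||
|     for text in sections: | |||
|         tokens = len(text.split())  # crude stand-in for the real tokenizer | |||
|         if count + tokens > chunk_token_num and chunks[-1]: | |||
|             chunks.append("") | |||
|             count = 0 | |||
|         chunks[-1] += ("\n" if chunks[-1] else "") + text | |||
|         count += tokens | |||
|     return [c for c in chunks if c] | |||
| print(naive_merge_sketch(["one two three"] * 4, chunk_token_num=6)) | |||
| # -> ['one two three\none two three', 'one two three\none two three'] | |||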
| @@ -114,6 +131,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca | |||
| if __name__ == "__main__": | |||
| import sys | |||
| def dummy(prog=None, msg=""): | |||
| pass | |||
| chunk(sys.argv[1], from_page=1, to_page=10, callback=dummy) | |||
| @@ -35,8 +35,10 @@ class Docx(DocxParser): | |||
| pn = 0 | |||
| lines = [] | |||
| for p in self.doc.paragraphs: | |||
| if pn > to_page:break | |||
| if from_page <= pn < to_page and p.text.strip(): lines.append(self.__clean(p.text)) | |||
| if pn > to_page: | |||
| break | |||
| if from_page <= pn < to_page and p.text.strip(): | |||
| lines.append(self.__clean(p.text)) | |||
| for run in p.runs: | |||
| if 'lastRenderedPageBreak' in run._element.xml: | |||
| pn += 1 | |||
| @@ -63,15 +65,18 @@ class Pdf(PdfParser): | |||
| start = timer() | |||
| self._layouts_rec(zoomin) | |||
| callback(0.67, "Layout analysis finished") | |||
| cron_logger.info("paddle layouts:".format((timer()-start)/(self.total_page+0.1))) | |||
| cron_logger.info("paddle layouts:".format( | |||
| (timer() - start) / (self.total_page + 0.1))) | |||
| self._naive_vertical_merge() | |||
| callback(0.8, "Text extraction finished") | |||
| return [(b["text"], self._line_tag(b, zoomin)) for b in self.boxes], None | |||
| return [(b["text"], self._line_tag(b, zoomin)) | |||
| for b in self.boxes], None | |||
| def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs): | |||
| def chunk(filename, binary=None, from_page=0, to_page=100000, | |||
| lang="Chinese", callback=None, **kwargs): | |||
| """ | |||
| Supported file formats are docx, pdf, txt. | |||
| """ | |||
| @@ -89,41 +94,50 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca | |||
| callback(0.8, "Finish parsing.") | |||
| elif re.search(r"\.pdf$", filename, re.IGNORECASE): | |||
| pdf_parser = Pdf() if kwargs.get("parser_config",{}).get("layout_recognize", True) else PlainParser() | |||
| for txt, poss in pdf_parser(filename if not binary else binary, | |||
| from_page=from_page, to_page=to_page, callback=callback)[0]: | |||
| sections.append(txt + poss) | |||
| pdf_parser = Pdf() if kwargs.get( | |||
| "parser_config", {}).get( | |||
| "layout_recognize", True) else PlainParser() | |||
| for txt, poss in pdf_parser(filename if not binary else binary, | |||
| from_page=from_page, to_page=to_page, callback=callback)[0]: | |||
| sections.append(txt + poss) | |||
| elif re.search(r"\.txt$", filename, re.IGNORECASE): | |||
| callback(0.1, "Start to parse.") | |||
| txt = "" | |||
| if binary:txt = binary.decode("utf-8") | |||
| if binary: | |||
| txt = binary.decode("utf-8") | |||
| else: | |||
| with open(filename, "r") as f: | |||
| while True: | |||
| l = f.readline() | |||
| if not l:break | |||
| if not l: | |||
| break | |||
| txt += l | |||
| sections = txt.split("\n") | |||
| sections = [l for l in sections if l] | |||
| callback(0.8, "Finish parsing.") | |||
| else: raise NotImplementedError("file type not supported yet(docx, pdf, txt supported)") | |||
| else: | |||
| raise NotImplementedError( | |||
| "file type not supported yet(docx, pdf, txt supported)") | |||
| # is it English | |||
| eng = lang.lower() == "english"#is_english(sections) | |||
| eng = lang.lower() == "english" # is_english(sections) | |||
| # Remove 'Contents' part | |||
| remove_contents_table(sections, eng) | |||
| make_colon_as_title(sections) | |||
| bull = bullets_category(sections) | |||
| chunks = hierarchical_merge(bull, sections, 3) | |||
| if not chunks: callback(0.99, "No chunk parsed out.") | |||
| if not chunks: | |||
| callback(0.99, "No chunk parsed out.") | |||
| return tokenize_chunks(["\n".join(ck) for ck in chunks], doc, eng, pdf_parser) | |||
| return tokenize_chunks(["\n".join(ck) | |||
| for ck in chunks], doc, eng, pdf_parser) | |||
| if __name__ == "__main__": | |||
| import sys | |||
| def dummy(prog=None, msg=""): | |||
| pass | |||
| chunk(sys.argv[1], callback=dummy) | |||
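| # bullets_category (used above) votes on a document's numbering style: count | |||
| # which bullet-pattern family matches the most lines, then hierarchical_merge | |||
| # splits at those bullets. A toy version of the voting step (the two pattern | |||
| # groups here are illustrative; the project's tables are much larger): | |||
| import re | |||
| BULLET_GROUPS = [ | |||
|     [r"^第[0-9一二三四五六七八九十]+条", r"^第[0-9一二三四五六七八九十]+章"], | |||
|     [r"^[0-9]+\.[0-9]+", r"^[0-9]+\."], | |||
| ] | |||
| def bullets_category_sketch(lines): | |||
|     hits = [sum(1 for l in lines for p in group if re.match(p, l.strip())) | |||
|             for group in BULLET_GROUPS] | |||
|     best = max(range(len(hits)), key=lambda i: hits[i]) | |||
|     return best if hits[best] else -1  # -1: no recognisable numbering | |||
| print(bullets_category_sketch(["第一条 总则", "第二条 适用范围"]))  # -> 0 | |||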
| @@ -25,10 +25,10 @@ class Pdf(PdfParser): | |||
| callback | |||
| ) | |||
| callback(msg="OCR finished.") | |||
| #for bb in self.boxes: | |||
| # for bb in self.boxes: | |||
| # for b in bb: | |||
| # print(b) | |||
| print("OCR:", timer()-start) | |||
| print("OCR:", timer() - start) | |||
| self._layouts_rec(zoomin) | |||
| callback(0.65, "Layout analysis finished.") | |||
| @@ -45,30 +45,35 @@ class Pdf(PdfParser): | |||
| for b in self.boxes: | |||
| b["text"] = re.sub(r"([\t ]|\u3000){2,}", " ", b["text"].strip()) | |||
| return [(b["text"], b.get("layout_no", ""), self.get_position(b, zoomin)) for i, b in enumerate(self.boxes)], tbls | |||
| return [(b["text"], b.get("layout_no", ""), self.get_position(b, zoomin)) | |||
| for i, b in enumerate(self.boxes)], tbls | |||
| def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs): | |||
| def chunk(filename, binary=None, from_page=0, to_page=100000, | |||
| lang="Chinese", callback=None, **kwargs): | |||
| """ | |||
| Only pdf is supported. | |||
| """ | |||
| pdf_parser = None | |||
| if re.search(r"\.pdf$", filename, re.IGNORECASE): | |||
| pdf_parser = Pdf() if kwargs.get("parser_config",{}).get("layout_recognize", True) else PlainParser() | |||
| pdf_parser = Pdf() if kwargs.get( | |||
| "parser_config", {}).get( | |||
| "layout_recognize", True) else PlainParser() | |||
| sections, tbls = pdf_parser(filename if not binary else binary, | |||
| from_page=from_page, to_page=to_page, callback=callback) | |||
| if sections and len(sections[0])<3: sections = [(t, l, [[0]*5]) for t, l in sections] | |||
| from_page=from_page, to_page=to_page, callback=callback) | |||
| if sections and len(sections[0]) < 3: | |||
| sections = [(t, l, [[0] * 5]) for t, l in sections] | |||
| else: raise NotImplementedError("file type not supported yet(pdf supported)") | |||
| else: | |||
| raise NotImplementedError("file type not supported yet(pdf supported)") | |||
| doc = { | |||
| "docnm_kwd": filename | |||
| } | |||
| doc["title_tks"] = huqie.qie(re.sub(r"\.[a-zA-Z]+$", "", doc["docnm_kwd"])) | |||
| doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"]) | |||
| # is it English | |||
| eng = lang.lower() == "english"#pdf_parser.is_english | |||
| eng = lang.lower() == "english" # pdf_parser.is_english | |||
| # set pivot using the most frequent type of title, | |||
| # then merge between 2 pivot | |||
| @@ -79,7 +84,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca | |||
| for txt, _, _ in sections: | |||
| for t, lvl in pdf_parser.outlines: | |||
| tks = set([t[i] + t[i + 1] for i in range(len(t) - 1)]) | |||
| tks_ = set([txt[i] + txt[i + 1] for i in range(min(len(t), len(txt) - 1))]) | |||
| tks_ = set([txt[i] + txt[i + 1] | |||
| for i in range(min(len(t), len(txt) - 1))]) | |||
| if len(set(tks & tks_)) / max([len(tks), len(tks_), 1]) > 0.8: | |||
| levels.append(lvl) | |||
| break | |||
| @@ -87,24 +93,27 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca | |||
| levels.append(max_lvl + 1) | |||
| else: | |||
| bull = bullets_category([txt for txt,_,_ in sections]) | |||
| most_level, levels = title_frequency(bull, [(txt, l) for txt, l, poss in sections]) | |||
| bull = bullets_category([txt for txt, _, _ in sections]) | |||
| most_level, levels = title_frequency( | |||
| bull, [(txt, l) for txt, l, poss in sections]) | |||
| assert len(sections) == len(levels) | |||
| sec_ids = [] | |||
| sid = 0 | |||
| for i, lvl in enumerate(levels): | |||
| if lvl <= most_level and i > 0 and lvl != levels[i - 1]: sid += 1 | |||
| if lvl <= most_level and i > 0 and lvl != levels[i - 1]: | |||
| sid += 1 | |||
| sec_ids.append(sid) | |||
| # print(lvl, self.boxes[i]["text"], most_level, sid) | |||
| sections = [(txt, sec_ids[i], poss) for i, (txt, _, poss) in enumerate(sections)] | |||
| sections = [(txt, sec_ids[i], poss) | |||
| for i, (txt, _, poss) in enumerate(sections)] | |||
| for (img, rows), poss in tbls: | |||
| sections.append((rows if isinstance(rows, str) else rows[0], -1, | |||
| [(p[0] + 1 - from_page, p[1], p[2], p[3], p[4]) for p in poss])) | |||
| def tag(pn, left, right, top, bottom): | |||
| if pn+left+right+top+bottom == 0: | |||
| if pn + left + right + top + bottom == 0: | |||
| return "" | |||
| return "@@{}\t{:.1f}\t{:.1f}\t{:.1f}\t{:.1f}##" \ | |||
| .format(pn, left, right, top, bottom) | |||
| @@ -112,7 +121,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca | |||
| chunks = [] | |||
| last_sid = -2 | |||
| tk_cnt = 0 | |||
| for txt, sec_id, poss in sorted(sections, key=lambda x: (x[-1][0][0], x[-1][0][3], x[-1][0][1])): | |||
| for txt, sec_id, poss in sorted(sections, key=lambda x: ( | |||
| x[-1][0][0], x[-1][0][3], x[-1][0][1])): | |||
| poss = "\t".join([tag(*pos) for pos in poss]) | |||
| if tk_cnt < 2048 and (sec_id == last_sid or sec_id == -1): | |||
| if chunks: | |||
| @@ -121,16 +131,17 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca | |||
| continue | |||
| chunks.append(txt + poss) | |||
| tk_cnt = num_tokens_from_string(txt) | |||
| if sec_id > -1: last_sid = sec_id | |||
| if sec_id > -1: | |||
| last_sid = sec_id | |||
| res = tokenize_table(tbls, doc, eng) | |||
| res.extend(tokenize_chunks(chunks, doc, eng, pdf_parser)) | |||
| return res | |||
| if __name__ == "__main__": | |||
| import sys | |||
| def dummy(prog=None, msg=""): | |||
| pass | |||
| chunk(sys.argv[1], callback=dummy) | |||
| @@ -44,11 +44,14 @@ class Pdf(PdfParser): | |||
| tbls = self._extract_table_figure(True, zoomin, True, True) | |||
| self._naive_vertical_merge() | |||
| cron_logger.info("paddle layouts:".format((timer() - start) / (self.total_page + 0.1))) | |||
| return [(b["text"], self._line_tag(b, zoomin)) for b in self.boxes], tbls | |||
| cron_logger.info("paddle layouts:".format( | |||
| (timer() - start) / (self.total_page + 0.1))) | |||
| return [(b["text"], self._line_tag(b, zoomin)) | |||
| for b in self.boxes], tbls | |||
| def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs): | |||
| def chunk(filename, binary=None, from_page=0, to_page=100000, | |||
| lang="Chinese", callback=None, **kwargs): | |||
| """ | |||
| Supported file formats are docx, pdf, excel, txt. | |||
| This method applies the naive way to chunk files. | |||
| @@ -56,8 +59,10 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca | |||
| Next, these successive pieces are merged into chunks whose token number is no more than 'Max token number'. | |||
| """ | |||
| eng = lang.lower() == "english"#is_english(cks) | |||
| parser_config = kwargs.get("parser_config", {"chunk_token_num": 128, "delimiter": "\n!?。;!?", "layout_recognize": True}) | |||
| eng = lang.lower() == "english" # is_english(cks) | |||
| parser_config = kwargs.get( | |||
| "parser_config", { | |||
| "chunk_token_num": 128, "delimiter": "\n!?。;!?", "layout_recognize": True}) | |||
| doc = { | |||
| "docnm_kwd": filename, | |||
| "title_tks": huqie.qie(re.sub(r"\.[a-zA-Z]+$", "", filename)) | |||
| @@ -73,9 +78,10 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca | |||
| callback(0.8, "Finish parsing.") | |||
| elif re.search(r"\.pdf$", filename, re.IGNORECASE): | |||
| pdf_parser = Pdf() if parser_config["layout_recognize"] else PlainParser() | |||
| pdf_parser = Pdf( | |||
| ) if parser_config["layout_recognize"] else PlainParser() | |||
| sections, tbls = pdf_parser(filename if not binary else binary, | |||
| from_page=from_page, to_page=to_page, callback=callback) | |||
| from_page=from_page, to_page=to_page, callback=callback) | |||
| res = tokenize_table(tbls, doc, eng) | |||
| elif re.search(r"\.xlsx?$", filename, re.IGNORECASE): | |||
| @@ -92,16 +98,21 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca | |||
| with open(filename, "r") as f: | |||
| while True: | |||
| l = f.readline() | |||
| if not l: break | |||
| if not l: | |||
| break | |||
| txt += l | |||
| sections = txt.split("\n") | |||
| sections = [(l, "") for l in sections if l] | |||
| callback(0.8, "Finish parsing.") | |||
| else: | |||
| raise NotImplementedError("file type not supported yet(docx, pdf, txt supported)") | |||
| raise NotImplementedError( | |||
| "file type not supported yet(docx, pdf, txt supported)") | |||
| chunks = naive_merge(sections, parser_config.get("chunk_token_num", 128), parser_config.get("delimiter", "\n!?。;!?")) | |||
| chunks = naive_merge( | |||
| sections, parser_config.get( | |||
| "chunk_token_num", 128), parser_config.get( | |||
| "delimiter", "\n!?。;!?")) | |||
| res.extend(tokenize_chunks(chunks, doc, eng, pdf_parser)) | |||
| return res | |||
| @@ -110,9 +121,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca | |||
| if __name__ == "__main__": | |||
| import sys | |||
| def dummy(prog=None, msg=""): | |||
| pass | |||
| chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy) | |||
| @@ -41,20 +41,23 @@ class Pdf(PdfParser): | |||
| tbls = self._extract_table_figure(True, zoomin, True, True) | |||
| self._concat_downward() | |||
| sections = [(b["text"], self.get_position(b, zoomin)) for i, b in enumerate(self.boxes)] | |||
| sections = [(b["text"], self.get_position(b, zoomin)) | |||
| for i, b in enumerate(self.boxes)] | |||
| for (img, rows), poss in tbls: | |||
| sections.append((rows if isinstance(rows, str) else rows[0], | |||
| [(p[0] + 1 - from_page, p[1], p[2], p[3], p[4]) for p in poss])) | |||
| return [(txt, "") for txt, _ in sorted(sections, key=lambda x: (x[-1][0][0], x[-1][0][3], x[-1][0][1]))], None | |||
| return [(txt, "") for txt, _ in sorted(sections, key=lambda x: ( | |||
| x[-1][0][0], x[-1][0][3], x[-1][0][1]))], None | |||
| def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs): | |||
| def chunk(filename, binary=None, from_page=0, to_page=100000, | |||
| lang="Chinese", callback=None, **kwargs): | |||
| """ | |||
| Supported file formats are docx, pdf, excel, txt. | |||
| One file forms a chunk which maintains original text order. | |||
| """ | |||
| eng = lang.lower() == "english"#is_english(cks) | |||
| eng = lang.lower() == "english" # is_english(cks) | |||
| if re.search(r"\.docx?$", filename, re.IGNORECASE): | |||
| callback(0.1, "Start to parse.") | |||
| @@ -62,8 +65,11 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca | |||
| callback(0.8, "Finish parsing.") | |||
| elif re.search(r"\.pdf$", filename, re.IGNORECASE): | |||
| pdf_parser = Pdf() if kwargs.get("parser_config",{}).get("layout_recognize", True) else PlainParser() | |||
| sections, _ = pdf_parser(filename if not binary else binary, to_page=to_page, callback=callback) | |||
| pdf_parser = Pdf() if kwargs.get( | |||
| "parser_config", {}).get( | |||
| "layout_recognize", True) else PlainParser() | |||
| sections, _ = pdf_parser( | |||
| filename if not binary else binary, to_page=to_page, callback=callback) | |||
| sections = [s for s, _ in sections if s] | |||
| elif re.search(r"\.xlsx?$", filename, re.IGNORECASE): | |||
| @@ -80,14 +86,16 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca | |||
| with open(filename, "r") as f: | |||
| while True: | |||
| l = f.readline() | |||
| if not l: break | |||
| if not l: | |||
| break | |||
| txt += l | |||
| sections = txt.split("\n") | |||
| sections = [s for s in sections if s] | |||
| callback(0.8, "Finish parsing.") | |||
| else: | |||
| raise NotImplementedError("file type not supported yet(docx, pdf, txt supported)") | |||
| raise NotImplementedError( | |||
| "file type not supported yet(docx, pdf, txt supported)") | |||
| doc = { | |||
| "docnm_kwd": filename, | |||
| @@ -101,9 +109,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca | |||
| if __name__ == "__main__": | |||
| import sys | |||
| def dummy(prog=None, msg=""): | |||
| pass | |||
| chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy) | |||
| @@ -67,11 +67,11 @@ class Pdf(PdfParser): | |||
| if from_page > 0: | |||
| return { | |||
| "title":"", | |||
| "title": "", | |||
| "authors": "", | |||
| "abstract": "", | |||
| "sections": [(b["text"] + self._line_tag(b, zoomin), b.get("layoutno", "")) for b in self.boxes if | |||
| re.match(r"(text|title)", b.get("layoutno", "text"))], | |||
| re.match(r"(text|title)", b.get("layoutno", "text"))], | |||
| "tables": tbls | |||
| } | |||
| # get title and authors | |||
| @@ -87,7 +87,8 @@ class Pdf(PdfParser): | |||
| title = "" | |||
| break | |||
| for j in range(3): | |||
| if _begin(self.boxes[i + j]["text"]): break | |||
| if _begin(self.boxes[i + j]["text"]): | |||
| break | |||
| authors.append(self.boxes[i + j]["text"]) | |||
| break | |||
| break | |||
| @@ -107,10 +108,15 @@ class Pdf(PdfParser): | |||
| abstr = txt + self._line_tag(self.boxes[i], zoomin) | |||
| i += 1 | |||
| break | |||
| if not abstr: i = 0 | |||
| if not abstr: | |||
| i = 0 | |||
| callback(0.8, "Page {}~{}: Text merging finished".format(from_page, min(to_page, self.total_page))) | |||
| for b in self.boxes: print(b["text"], b.get("layoutno")) | |||
| callback( | |||
| 0.8, "Page {}~{}: Text merging finished".format( | |||
| from_page, min( | |||
| to_page, self.total_page))) | |||
| for b in self.boxes: | |||
| print(b["text"], b.get("layoutno")) | |||
| print(tbls) | |||
| return { | |||
| @@ -118,19 +124,20 @@ class Pdf(PdfParser): | |||
| "authors": " ".join(authors), | |||
| "abstract": abstr, | |||
| "sections": [(b["text"] + self._line_tag(b, zoomin), b.get("layoutno", "")) for b in self.boxes[i:] if | |||
| re.match(r"(text|title)", b.get("layoutno", "text"))], | |||
| re.match(r"(text|title)", b.get("layoutno", "text"))], | |||
| "tables": tbls | |||
| } | |||
| def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs): | |||
| def chunk(filename, binary=None, from_page=0, to_page=100000, | |||
| lang="Chinese", callback=None, **kwargs): | |||
| """ | |||
| Only pdf is supported. | |||
| The abstract of the paper will be sliced as one entire chunk and will not be split further. | |||
| """ | |||
| pdf_parser = None | |||
| if re.search(r"\.pdf$", filename, re.IGNORECASE): | |||
| if not kwargs.get("parser_config",{}).get("layout_recognize", True): | |||
| if not kwargs.get("parser_config", {}).get("layout_recognize", True): | |||
| pdf_parser = PlainParser() | |||
| paper = { | |||
| "title": filename, | |||
| @@ -143,14 +150,15 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca | |||
| pdf_parser = Pdf() | |||
| paper = pdf_parser(filename if not binary else binary, | |||
| from_page=from_page, to_page=to_page, callback=callback) | |||
| else: raise NotImplementedError("file type not supported yet(pdf supported)") | |||
| else: | |||
| raise NotImplementedError("file type not supported yet(pdf supported)") | |||
| doc = {"docnm_kwd": filename, "authors_tks": huqie.qie(paper["authors"]), | |||
| "title_tks": huqie.qie(paper["title"] if paper["title"] else filename)} | |||
| doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"]) | |||
| doc["authors_sm_tks"] = huqie.qieqie(doc["authors_tks"]) | |||
| # is it English | |||
| eng = lang.lower() == "english"#pdf_parser.is_english | |||
| eng = lang.lower() == "english" # pdf_parser.is_english | |||
| print("It's English.....", eng) | |||
| res = tokenize_table(paper["tables"], doc, eng) | |||
| @@ -160,7 +168,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca | |||
| txt = pdf_parser.remove_tag(paper["abstract"]) | |||
| d["important_kwd"] = ["abstract", "总结", "概括", "summary", "summarize"] | |||
| d["important_tks"] = " ".join(d["important_kwd"]) | |||
| d["image"], poss = pdf_parser.crop(paper["abstract"], need_position=True) | |||
| d["image"], poss = pdf_parser.crop( | |||
| paper["abstract"], need_position=True) | |||
| add_positions(d, poss) | |||
| tokenize(d, txt, eng) | |||
| res.append(d) | |||
| @@ -174,7 +183,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca | |||
| sec_ids = [] | |||
| sid = 0 | |||
| for i, lvl in enumerate(levels): | |||
| if lvl <= most_level and i > 0 and lvl != levels[i-1]: sid += 1 | |||
| if lvl <= most_level and i > 0 and lvl != levels[i - 1]: | |||
| sid += 1 | |||
| sec_ids.append(sid) | |||
| print(lvl, sorted_sections[i][0], most_level, sid) | |||
| @@ -190,6 +200,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca | |||
| res.extend(tokenize_chunks(chunks, doc, eng, pdf_parser)) | |||
| return res | |||
| """ | |||
| readed = [0] * len(paper["lines"]) | |||
| # find colon firstly | |||
| @@ -212,7 +223,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca | |||
| for k in range(j, i): readed[k] = True | |||
| txt = txt[::-1] | |||
| if eng: | |||
| r = re.search(r"(.*?) ([\.;?!]|$)", txt) | |||
| r = re.search(r"(.*?) ([\\.;?!]|$)", txt) | |||
| txt = r.group(1)[::-1] if r else txt[::-1] | |||
| else: | |||
| r = re.search(r"(.*?) ([。?;!]|$)", txt) | |||
| @@ -270,6 +281,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca | |||
| if __name__ == "__main__": | |||
| import sys | |||
| def dummy(prog=None, msg=""): | |||
| pass | |||
| chunk(sys.argv[1], callback=dummy) | |||
| @@ -33,9 +33,12 @@ class Ppt(PptParser): | |||
| with slides.Presentation(BytesIO(fnm)) as presentation: | |||
| for i, slide in enumerate(presentation.slides[from_page: to_page]): | |||
| buffered = BytesIO() | |||
| slide.get_thumbnail(0.5, 0.5).save(buffered, drawing.imaging.ImageFormat.jpeg) | |||
| slide.get_thumbnail(0.5, 0.5).save( | |||
| buffered, drawing.imaging.ImageFormat.jpeg) | |||
| imgs.append(Image.open(buffered)) | |||
| assert len(imgs) == len(txts), "Slides text and image do not match: {} vs. {}".format(len(imgs), len(txts)) | |||
| assert len(imgs) == len(txts), \ | |||
| "Slides text and image do not match: {} vs. {}".format(len(imgs), len(txts)) | |||
| callback(0.9, "Image extraction finished") | |||
| self.is_english = is_english(txts) | |||
| return [(txts[i], imgs[i]) for i in range(len(txts))] | |||
| @@ -47,25 +50,34 @@ class Pdf(PdfParser): | |||
| def __garbage(self, txt): | |||
| txt = txt.lower().strip() | |||
| if re.match(r"[0-9\.,%/-]+$", txt): return True | |||
| if len(txt) < 3:return True | |||
| if re.match(r"[0-9\.,%/-]+$", txt): | |||
| return True | |||
| if len(txt) < 3: | |||
| return True | |||
| return False | |||
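| # For instance: __garbage("12.5%") and __garbage("p3") are True (a pure | |||
| # numeric/punctuation run, and a too-short string), while a normal line such | |||
| # as __garbage("revenue grew 12%") is False. | |||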
| def __call__(self, filename, binary=None, from_page=0, to_page=100000, zoomin=3, callback=None): | |||
| def __call__(self, filename, binary=None, from_page=0, | |||
| to_page=100000, zoomin=3, callback=None): | |||
| callback(msg="OCR is running...") | |||
| self.__images__(filename if not binary else binary, zoomin, from_page, to_page, callback) | |||
| callback(0.8, "Page {}~{}: OCR finished".format(from_page, min(to_page, self.total_page))) | |||
| assert len(self.boxes) == len(self.page_images), "{} vs. {}".format(len(self.boxes), len(self.page_images)) | |||
| self.__images__(filename if not binary else binary, | |||
| zoomin, from_page, to_page, callback) | |||
| callback(0.8, "Page {}~{}: OCR finished".format( | |||
| from_page, min(to_page, self.total_page))) | |||
| assert len(self.boxes) == len(self.page_images), "{} vs. {}".format( | |||
| len(self.boxes), len(self.page_images)) | |||
| res = [] | |||
| for i in range(len(self.boxes)): | |||
| lines = "\n".join([b["text"] for b in self.boxes[i] if not self.__garbage(b["text"])]) | |||
| lines = "\n".join([b["text"] for b in self.boxes[i] | |||
| if not self.__garbage(b["text"])]) | |||
| res.append((lines, self.page_images[i])) | |||
| callback(0.9, "Page {}~{}: Parsing finished".format(from_page, min(to_page, self.total_page))) | |||
| callback(0.9, "Page {}~{}: Parsing finished".format( | |||
| from_page, min(to_page, self.total_page))) | |||
| return res | |||
| class PlainPdf(PlainParser): | |||
| def __call__(self, filename, binary=None, from_page=0, to_page=100000, callback=None, **kwargs): | |||
| def __call__(self, filename, binary=None, from_page=0, | |||
| to_page=100000, callback=None, **kwargs): | |||
| self.pdf = pdf2_read(filename if not binary else BytesIO(binary)) | |||
| page_txt = [] | |||
| for page in self.pdf.pages[from_page: to_page]: | |||
| @@ -74,7 +86,8 @@ class PlainPdf(PlainParser): | |||
| return [(txt, None) for txt in page_txt] | |||
| def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs): | |||
| def chunk(filename, binary=None, from_page=0, to_page=100000, | |||
| lang="Chinese", callback=None, **kwargs): | |||
| """ | |||
| The supported file formats are pdf, pptx. | |||
| Every page will be treated as a chunk, and the thumbnail of every page will be stored. | |||
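| # Illustrative call (a sketch, not code from this change): each returned dict | |||
| # carries the page text plus image and position fields, e.g. | |||
| # | |||
| #     for d in chunk("slides.pptx", callback=dummy): | |||
| #         print(d["page_num_int"], d["image"].size) | |||
| # | |||
| # For the plain-pdf path "image" may be absent, as guarded a few hunks below. | |||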
| @@ -89,35 +102,42 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca | |||
| res = [] | |||
| if re.search(r"\.pptx?$", filename, re.IGNORECASE): | |||
| ppt_parser = Ppt() | |||
| for pn, (txt,img) in enumerate(ppt_parser(filename if not binary else binary, from_page, 1000000, callback)): | |||
| for pn, (txt, img) in enumerate(ppt_parser( | |||
| filename if not binary else binary, from_page, 1000000, callback)): | |||
| d = copy.deepcopy(doc) | |||
| pn += from_page | |||
| d["image"] = img | |||
| d["page_num_int"] = [pn+1] | |||
| d["page_num_int"] = [pn + 1] | |||
| d["top_int"] = [0] | |||
| d["position_int"] = [(pn + 1, 0, img.size[0], 0, img.size[1])] | |||
| tokenize(d, txt, eng) | |||
| res.append(d) | |||
| return res | |||
| elif re.search(r"\.pdf$", filename, re.IGNORECASE): | |||
| pdf_parser = Pdf() if kwargs.get("parser_config",{}).get("layout_recognize", True) else PlainPdf() | |||
| for pn, (txt,img) in enumerate(pdf_parser(filename, binary, from_page=from_page, to_page=to_page, callback=callback)): | |||
| pdf_parser = Pdf() if kwargs.get( | |||
| "parser_config", {}).get("layout_recognize", True) else PlainPdf() | |||
| for pn, (txt, img) in enumerate(pdf_parser(filename, binary, | |||
| from_page=from_page, to_page=to_page, callback=callback)): | |||
| d = copy.deepcopy(doc) | |||
| pn += from_page | |||
| if img: d["image"] = img | |||
| d["page_num_int"] = [pn+1] | |||
| if img: | |||
| d["image"] = img | |||
| d["page_num_int"] = [pn + 1] | |||
| d["top_int"] = [0] | |||
| d["position_int"] = [(pn + 1, 0, img.size[0] if img else 0, 0, img.size[1] if img else 0)] | |||
| d["position_int"] = [ | |||
| (pn + 1, 0, img.size[0] if img else 0, 0, img.size[1] if img else 0)] | |||
| tokenize(d, txt, eng) | |||
| res.append(d) | |||
| return res | |||
| raise NotImplementedError("file type not supported yet(pptx, pdf supported)") | |||
| raise NotImplementedError( | |||
| "file type not supported yet(pptx, pdf supported)") | |||
| if __name__== "__main__": | |||
| if __name__ == "__main__": | |||
| import sys | |||
| def dummy(a, b): | |||
| pass | |||
| chunk(sys.argv[1], callback=dummy) | |||
| @@ -27,6 +27,8 @@ from rag.utils import rmSpace | |||
| forbidden_select_fields4resume = [ | |||
| "name_pinyin_kwd", "edu_first_fea_kwd", "degree_kwd", "sch_rank_kwd", "edu_fea_kwd" | |||
| ] | |||
| def remote_call(filename, binary): | |||
| q = { | |||
| "header": { | |||
| @@ -48,18 +50,22 @@ def remote_call(filename, binary): | |||
| } | |||
| for _ in range(3): | |||
| try: | |||
| resume = requests.post("http://127.0.0.1:61670/tog", data=json.dumps(q)) | |||
| resume = requests.post( | |||
| "http://127.0.0.1:61670/tog", | |||
| data=json.dumps(q)) | |||
| resume = resume.json()["response"]["results"] | |||
| resume = refactor(resume) | |||
| for k in ["education", "work", "project", "training", "skill", "certificate", "language"]: | |||
| if not resume.get(k) and k in resume: del resume[k] | |||
| for k in ["education", "work", "project", | |||
| "training", "skill", "certificate", "language"]: | |||
| if not resume.get(k) and k in resume: | |||
| del resume[k] | |||
| resume = step_one.refactor(pd.DataFrame([{"resume_content": json.dumps(resume), "tob_resume_id": "x", | |||
| "updated_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}])) | |||
| "updated_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}])) | |||
| resume = step_two.parse(resume) | |||
| return resume | |||
| except Exception as e: | |||
| cron_logger.error("Resume parser error: "+str(e)) | |||
| cron_logger.error("Resume parser error: " + str(e)) | |||
| return {} | |||
| @@ -144,10 +150,13 @@ def chunk(filename, binary=None, callback=None, **kwargs): | |||
| doc["content_ltks"] = huqie.qie(doc["content_with_weight"]) | |||
| doc["content_sm_ltks"] = huqie.qieqie(doc["content_ltks"]) | |||
| for n, _ in field_map.items(): | |||
| if n not in resume:continue | |||
| if isinstance(resume[n], list) and (len(resume[n]) == 1 or n not in forbidden_select_fields4resume): | |||
| if n not in resume: | |||
| continue | |||
| if isinstance(resume[n], list) and ( | |||
| len(resume[n]) == 1 or n not in forbidden_select_fields4resume): | |||
| resume[n] = resume[n][0] | |||
| if n.find("_tks")>0: resume[n] = huqie.qieqie(resume[n]) | |||
| if n.find("_tks") > 0: | |||
| resume[n] = huqie.qieqie(resume[n]) | |||
| doc[n] = resume[n] | |||
| print(doc) | |||
| @@ -25,7 +25,8 @@ from deepdoc.parser import ExcelParser | |||
| class Excel(ExcelParser): | |||
| def __call__(self, fnm, binary=None, from_page=0, to_page=10000000000, callback=None): | |||
| def __call__(self, fnm, binary=None, from_page=0, | |||
| to_page=10000000000, callback=None): | |||
| if not binary: | |||
| wb = load_workbook(fnm) | |||
| else: | |||
| @@ -48,8 +49,10 @@ class Excel(ExcelParser): | |||
| data = [] | |||
| for i, r in enumerate(rows[1:]): | |||
| rn += 1 | |||
| if rn-1 < from_page:continue | |||
| if rn -1>=to_page: break | |||
| if rn - 1 < from_page: | |||
| continue | |||
| if rn - 1 >= to_page: | |||
| break | |||
| row = [ | |||
| cell.value for ii, | |||
| cell in enumerate(r) if ii not in missed] | |||
| @@ -60,7 +63,7 @@ class Excel(ExcelParser): | |||
| done += 1 | |||
| res.append(pd.DataFrame(np.array(data), columns=headers)) | |||
| callback(0.3, ("Extract records: {}~{}".format(from_page+1, min(to_page, from_page+rn)) + ( | |||
| callback(0.3, ("Extract records: {}~{}".format(from_page + 1, min(to_page, from_page + rn)) + ( | |||
| f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else ""))) | |||
| return res | |||
| @@ -73,7 +76,8 @@ def trans_datatime(s): | |||
| def trans_bool(s): | |||
| if re.match(r"(true|yes|是|\*|✓|✔|☑|✅|√)$", str(s).strip(), flags=re.IGNORECASE): | |||
| if re.match(r"(true|yes|是|\*|✓|✔|☑|✅|√)$", | |||
| str(s).strip(), flags=re.IGNORECASE): | |||
| return "yes" | |||
| if re.match(r"(false|no|否|⍻|×)$", str(s).strip(), flags=re.IGNORECASE): | |||
| return "no" | |||
| @@ -107,13 +111,14 @@ def column_data_type(arr): | |||
| arr[i] = trans[ty](str(arr[i])) | |||
| except Exception as e: | |||
| arr[i] = None | |||
| #if ty == "text": | |||
| # if ty == "text": | |||
| # if len(arr) > 128 and uni / len(arr) < 0.1: | |||
| # ty = "keyword" | |||
| return arr, ty | |||
| def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese", callback=None, **kwargs): | |||
| def chunk(filename, binary=None, from_page=0, to_page=10000000000, | |||
| lang="Chinese", callback=None, **kwargs): | |||
| """ | |||
| Excel and csv(txt) format files are supported. | |||
| For csv or txt files, the delimiter between columns is TAB. | |||
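| # A sketch of a call with an explicit row window ("delimiter" is read from | |||
| # kwargs in the csv branch below and defaults to TAB): | |||
| # | |||
| #     docs = chunk("records.csv", from_page=0, to_page=3000, lang="English", | |||
| #                  callback=dummy, delimiter="\t") | |||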
| @@ -131,7 +136,12 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese | |||
| if re.search(r"\.xlsx?$", filename, re.IGNORECASE): | |||
| callback(0.1, "Start to parse.") | |||
| excel_parser = Excel() | |||
| dfs = excel_parser(filename, binary, from_page=from_page, to_page=to_page, callback=callback) | |||
| dfs = excel_parser( | |||
| filename, | |||
| binary, | |||
| from_page=from_page, | |||
| to_page=to_page, | |||
| callback=callback) | |||
| elif re.search(r"\.(txt|csv)$", filename, re.IGNORECASE): | |||
| callback(0.1, "Start to parse.") | |||
| txt = "" | |||
| @@ -149,8 +159,10 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese | |||
| headers = lines[0].split(kwargs.get("delimiter", "\t")) | |||
| rows = [] | |||
| for i, line in enumerate(lines[1:]): | |||
| if i < from_page:continue | |||
| if i >= to_page: break | |||
| if i < from_page: | |||
| continue | |||
| if i >= to_page: | |||
| break | |||
| row = [l for l in line.split(kwargs.get("delimiter", "\t"))] | |||
| if len(row) != len(headers): | |||
| fails.append(str(i)) | |||
| @@ -181,7 +193,13 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese | |||
| del df[n] | |||
| clmns = df.columns.values | |||
| txts = list(copy.deepcopy(clmns)) | |||
| py_clmns = [PY.get_pinyins(re.sub(r"(/.*|([^()]+?)|\([^()]+?\))", "", n), '_')[0] for n in clmns] | |||
| py_clmns = [ | |||
| PY.get_pinyins(re.sub(r"(/.*|([^()]+?)|\([^()]+?\))", "", n), '_')[0] | |||
| for n in clmns] | |||
| clmn_tys = [] | |||
| for j in range(len(clmns)): | |||
| cln, ty = column_data_type(df[clmns[j]]) | |||
| @@ -192,7 +210,7 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese | |||
| clmns_map = [(py_clmns[i].lower() + fieds_map[clmn_tys[i]], clmns[i].replace("_", " ")) | |||
| for i in range(len(clmns))] | |||
| eng = lang.lower() == "english"#is_english(txts) | |||
| eng = lang.lower() == "english" # is_english(txts) | |||
| for ii, row in df.iterrows(): | |||
| d = { | |||
| "docnm_kwd": filename, | |||
| @@ -13,6 +13,8 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| from zhipuai import ZhipuAI | |||
| from dashscope import Generation | |||
| from abc import ABC | |||
| from openai import OpenAI | |||
| import openai | |||
| @@ -34,7 +36,8 @@ class GptTurbo(Base): | |||
| self.model_name = model_name | |||
| def chat(self, system, history, gen_conf): | |||
| if system: history.insert(0, {"role": "system", "content": system}) | |||
| if system: | |||
| history.insert(0, {"role": "system", "content": system}) | |||
| try: | |||
| response = self.client.chat.completions.create( | |||
| model=self.model_name, | |||
| @@ -46,16 +49,18 @@ class GptTurbo(Base): | |||
| [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?" | |||
| return ans, response.usage.completion_tokens | |||
| except openai.APIError as e: | |||
| return "**ERROR**: "+str(e), 0 | |||
| return "**ERROR**: " + str(e), 0 | |||
| class MoonshotChat(GptTurbo): | |||
| def __init__(self, key, model_name="moonshot-v1-8k"): | |||
| self.client = OpenAI(api_key=key, base_url="https://api.moonshot.cn/v1",) | |||
| self.client = OpenAI( | |||
| api_key=key, base_url="https://api.moonshot.cn/v1",) | |||
| self.model_name = model_name | |||
| def chat(self, system, history, gen_conf): | |||
| if system: history.insert(0, {"role": "system", "content": system}) | |||
| if system: | |||
| history.insert(0, {"role": "system", "content": system}) | |||
| try: | |||
| response = self.client.chat.completions.create( | |||
| model=self.model_name, | |||
| @@ -67,10 +72,9 @@ class MoonshotChat(GptTurbo): | |||
| [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?" | |||
| return ans, response.usage.completion_tokens | |||
| except openai.APIError as e: | |||
| return "**ERROR**: "+str(e), 0 | |||
| return "**ERROR**: " + str(e), 0 | |||
| from dashscope import Generation | |||
| class QWenChat(Base): | |||
| def __init__(self, key, model_name=Generation.Models.qwen_turbo): | |||
| import dashscope | |||
| @@ -79,7 +83,8 @@ class QWenChat(Base): | |||
| def chat(self, system, history, gen_conf): | |||
| from http import HTTPStatus | |||
| if system: history.insert(0, {"role": "system", "content": system}) | |||
| if system: | |||
| history.insert(0, {"role": "system", "content": system}) | |||
| response = Generation.call( | |||
| self.model_name, | |||
| messages=history, | |||
| @@ -92,20 +97,21 @@ class QWenChat(Base): | |||
| ans += response.output.choices[0]['message']['content'] | |||
| tk_count += response.usage.output_tokens | |||
| if response.output.choices[0].get("finish_reason", "") == "length": | |||
| ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?" | |||
| ans += "...\nFor the content length reason, it stopped, continue?" if is_english( | |||
| [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?" | |||
| return ans, tk_count | |||
| return "**ERROR**: " + response.message, tk_count | |||
| from zhipuai import ZhipuAI | |||
| class ZhipuChat(Base): | |||
| def __init__(self, key, model_name="glm-3-turbo"): | |||
| self.client = ZhipuAI(api_key=key) | |||
| self.model_name = model_name | |||
| def chat(self, system, history, gen_conf): | |||
| if system: history.insert(0, {"role": "system", "content": system}) | |||
| if system: | |||
| history.insert(0, {"role": "system", "content": system}) | |||
| try: | |||
| response = self.client.chat.completions.create( | |||
| model=self.model_name,  # pass model by keyword, as the other chat clients do | |||
| @@ -120,6 +126,7 @@ class ZhipuChat(Base): | |||
| except Exception as e: | |||
| return "**ERROR**: " + str(e), 0 | |||
| class LocalLLM(Base): | |||
| class RPCProxy: | |||
| def __init__(self, host, port): | |||
| @@ -129,14 +136,17 @@ class LocalLLM(Base): | |||
| def __conn(self): | |||
| from multiprocessing.connection import Client | |||
| self._connection = Client((self.host, self.port), authkey=b'infiniflow-token4kevinhu') | |||
| self._connection = Client( | |||
| (self.host, self.port), authkey=b'infiniflow-token4kevinhu') | |||
| def __getattr__(self, name): | |||
| import pickle | |||
| def do_rpc(*args, **kwargs): | |||
| for _ in range(3): | |||
| try: | |||
| self._connection.send(pickle.dumps((name, args, kwargs))) | |||
| self._connection.send( | |||
| pickle.dumps((name, args, kwargs))) | |||
| return pickle.loads(self._connection.recv()) | |||
| except Exception as e: | |||
| self.__conn() | |||
| @@ -148,7 +158,8 @@ class LocalLLM(Base): | |||
| self.client = LocalLLM.RPCProxy("127.0.0.1", 7860) | |||
| def chat(self, system, history, gen_conf): | |||
| if system: history.insert(0, {"role": "system", "content": system}) | |||
| if system: | |||
| history.insert(0, {"role": "system", "content": system}) | |||
| try: | |||
| ans = self.client.chat( | |||
| history, | |||
| @@ -13,6 +13,7 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| from zhipuai import ZhipuAI | |||
| import io | |||
| from abc import ABC | |||
| @@ -57,8 +58,8 @@ class Base(ABC): | |||
| }, | |||
| }, | |||
| { | |||
| "text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else \ | |||
| "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.", | |||
| "text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else | |||
| "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.", | |||
| }, | |||
| ], | |||
| } | |||
| @@ -92,8 +93,9 @@ class QWenCV(Base): | |||
| def prompt(self, binary): | |||
| # awkward, but dashscope's multimodal API wants a local file URL (see | |||
| # "file://{path}" below), so the image is dumped to a temp file first | |||
| tmp_dir = get_project_base_directory("tmp") | |||
| if not os.path.exists(tmp_dir): os.mkdir(tmp_dir) | |||
| path = os.path.join(tmp_dir, "%s.jpg"%get_uuid()) | |||
| if not os.path.exists(tmp_dir): | |||
| os.mkdir(tmp_dir) | |||
| path = os.path.join(tmp_dir, "%s.jpg" % get_uuid()) | |||
| Image.open(io.BytesIO(binary)).save(path) | |||
| return [ | |||
| { | |||
| @@ -103,8 +105,8 @@ class QWenCV(Base): | |||
| "image": f"file://{path}" | |||
| }, | |||
| { | |||
| "text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else \ | |||
| "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.", | |||
| "text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else | |||
| "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.", | |||
| }, | |||
| ], | |||
| } | |||
| @@ -120,9 +122,6 @@ class QWenCV(Base): | |||
| return response.message, 0 | |||
| from zhipuai import ZhipuAI | |||
| class Zhipu4V(Base): | |||
| def __init__(self, key, model_name="glm-4v", lang="Chinese"): | |||
| self.client = ZhipuAI(api_key=key) | |||
| @@ -13,6 +13,7 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| from zhipuai import ZhipuAI | |||
| import os | |||
| from abc import ABC | |||
| @@ -40,11 +41,11 @@ flag_model = FlagModel(model_dir, | |||
| query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:", | |||
| use_fp16=torch.cuda.is_available()) | |||
| class Base(ABC): | |||
| def __init__(self, key, model_name): | |||
| pass | |||
| def encode(self, texts: list, batch_size=32): | |||
| raise NotImplementedError("Please implement encode method!") | |||
| @@ -67,11 +68,11 @@ class HuEmbedding(Base): | |||
| """ | |||
| self.model = flag_model | |||
| def encode(self, texts: list, batch_size=32): | |||
| texts = [t[:2000] for t in texts] | |||
| token_count = 0 | |||
| for t in texts: token_count += num_tokens_from_string(t) | |||
| for t in texts: | |||
| token_count += num_tokens_from_string(t) | |||
| res = [] | |||
| for i in range(0, len(texts), batch_size): | |||
| res.extend(self.model.encode(texts[i:i + batch_size]).tolist()) | |||
| @@ -90,7 +91,8 @@ class OpenAIEmbed(Base): | |||
| def encode(self, texts: list, batch_size=32): | |||
| res = self.client.embeddings.create(input=texts, | |||
| model=self.model_name) | |||
| return np.array([d.embedding for d in res.data]), res.usage.total_tokens | |||
| return (np.array([d.embedding for d in res.data]), | |||
| res.usage.total_tokens) | |||
| def encode_queries(self, text): | |||
| res = self.client.embeddings.create(input=[text], | |||
| @@ -111,7 +113,7 @@ class QWenEmbed(Base): | |||
| for i in range(0, len(texts), batch_size): | |||
| resp = dashscope.TextEmbedding.call( | |||
| model=self.model_name, | |||
| input=texts[i:i+batch_size], | |||
| input=texts[i:i + batch_size], | |||
| text_type="document" | |||
| ) | |||
| embds = [[] for _ in range(len(resp["output"]["embeddings"]))] | |||
| @@ -123,14 +125,14 @@ class QWenEmbed(Base): | |||
| def encode_queries(self, text): | |||
| resp = dashscope.TextEmbedding.call( | |||
| model=self.model_name, | |||
| input=text[:2048], | |||
| text_type="query" | |||
| ) | |||
| return np.array(resp["output"]["embeddings"][0]["embedding"]), resp["usage"]["total_tokens"] | |||
| model=self.model_name, | |||
| input=text[:2048], | |||
| text_type="query" | |||
| ) | |||
| return np.array(resp["output"]["embeddings"][0] | |||
| ["embedding"]), resp["usage"]["total_tokens"] | |||
| from zhipuai import ZhipuAI | |||
| class ZhipuEmbed(Base): | |||
| def __init__(self, key, model_name="embedding-2"): | |||
| self.client = ZhipuAI(api_key=key) | |||
| @@ -139,9 +141,10 @@ class ZhipuEmbed(Base): | |||
| def encode(self, texts: list, batch_size=32): | |||
| res = self.client.embeddings.create(input=texts, | |||
| model=self.model_name) | |||
| return np.array([d.embedding for d in res.data]), res.usage.total_tokens | |||
| return (np.array([d.embedding for d in res.data]), | |||
| res.usage.total_tokens) | |||
| def encode_queries(self, text): | |||
| res = self.client.embeddings.create(input=text, | |||
| model=self.model_name) | |||
| return np.array(res["data"][0]["embedding"]), res.usage.total_tokens | |||
| return np.array(res["data"][0]["embedding"]), res.usage.total_tokens | |||
| @@ -9,7 +9,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer | |||
| class RPCHandler: | |||
| def __init__(self): | |||
| self._functions = { } | |||
| self._functions = {} | |||
| def register_function(self, func): | |||
| self._functions[func.__name__] = func | |||
| @@ -21,12 +21,12 @@ class RPCHandler: | |||
| func_name, args, kwargs = pickle.loads(connection.recv()) | |||
| # Run the RPC and send a response | |||
| try: | |||
| r = self._functions[func_name](*args,**kwargs) | |||
| r = self._functions[func_name](*args, **kwargs) | |||
| connection.send(pickle.dumps(r)) | |||
| except Exception as e: | |||
| connection.send(pickle.dumps(e)) | |||
| except EOFError: | |||
| pass | |||
| pass | |||
| def rpc_server(hdlr, address, authkey): | |||
| @@ -44,11 +44,17 @@ def rpc_server(hdlr, address, authkey): | |||
| models = [] | |||
| tokenizer = None | |||
| def chat(messages, gen_conf): | |||
| global tokenizer | |||
| model = Model() | |||
| try: | |||
| conf = {"max_new_tokens": int(gen_conf.get("max_tokens", 256)), "temperature": float(gen_conf.get("temperature", 0.1))} | |||
| conf = { | |||
| "max_new_tokens": int(gen_conf.get("max_tokens", 256)), | |||
| "temperature": float(gen_conf.get("temperature", 0.1))} | |||
| print(messages, conf) | |||
| text = tokenizer.apply_chat_template( | |||
| messages, | |||
| @@ -65,7 +71,8 @@ def chat(messages, gen_conf): | |||
| output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids) | |||
| ] | |||
| return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] | |||
| return tokenizer.batch_decode( | |||
| generated_ids, skip_special_tokens=True)[0] | |||
| except Exception as e: | |||
| return str(e) | |||
| @@ -75,10 +82,15 @@ def Model(): | |||
| random.seed(time.time()) | |||
| return random.choice(models) | |||
| if __name__ == "__main__": | |||
| parser = argparse.ArgumentParser() | |||
| parser.add_argument("--model_name", type=str, help="Model name") | |||
| parser.add_argument("--port", default=7860, type=int, help="RPC serving port") | |||
| parser.add_argument( | |||
| "--port", | |||
| default=7860, | |||
| type=int, | |||
| help="RPC serving port") | |||
| args = parser.parse_args() | |||
| handler = RPCHandler() | |||
| @@ -93,4 +105,5 @@ if __name__ == "__main__": | |||
| tokenizer = AutoTokenizer.from_pretrained(args.model_name) | |||
| # Run the server | |||
| rpc_server(handler, ('0.0.0.0', args.port), authkey=b'infiniflow-token4kevinhu') | |||
| rpc_server(handler, ('0.0.0.0', args.port), | |||
| authkey=b'infiniflow-token4kevinhu') | |||
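| # A sketch of the matching client side, mirroring LocalLLM.RPCProxy from the | |||
| # chat-model module earlier in this diff; it assumes chat() was registered on | |||
| # the handler in the elided lines above. The wire format is a pickled | |||
| # (func_name, args, kwargs) triple, answered with a pickled result. | |||
| # | |||
| #     from multiprocessing.connection import Client | |||
| #     import pickle | |||
| #     conn = Client(("127.0.0.1", 7860), authkey=b'infiniflow-token4kevinhu') | |||
| #     messages = [{"role": "user", "content": "hi"}] | |||
| #     conn.send(pickle.dumps(("chat", (messages, {"temperature": 0.1}), {}))) | |||
| #     print(pickle.loads(conn.recv())) | |||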
| @@ -372,7 +372,8 @@ class PptChunker(HuChunker): | |||
| tb = shape.table | |||
| rows = [] | |||
| for i in range(1, len(tb.rows)): | |||
| rows.append("; ".join([tb.cell(0, j).text + ": " + tb.cell(i, j).text for j in range(len(tb.columns)) if tb.cell(i, j)])) | |||
| rows.append("; ".join([tb.cell( | |||
| 0, j).text + ": " + tb.cell(i, j).text for j in range(len(tb.columns)) if tb.cell(i, j)])) | |||
| return "\n".join(rows) | |||
| if shape.has_text_frame: | |||
| @@ -382,7 +383,8 @@ class PptChunker(HuChunker): | |||
| texts = [] | |||
| for p in shape.shapes: | |||
| t = self.__extract(p) | |||
| if t: texts.append(t) | |||
| if t: | |||
| texts.append(t) | |||
| return "\n".join(texts) | |||
| def __call__(self, fnm): | |||
| @@ -395,7 +397,8 @@ class PptChunker(HuChunker): | |||
| texts = [] | |||
| for shape in slide.shapes: | |||
| txt = self.__extract(shape) | |||
| if txt: texts.append(txt) | |||
| if txt: | |||
| texts.append(txt) | |||
| txts.append("\n".join(texts)) | |||
| import aspose.slides as slides | |||
| @@ -404,9 +407,12 @@ class PptChunker(HuChunker): | |||
| with slides.Presentation(BytesIO(fnm)) as presentation: | |||
| for slide in presentation.slides: | |||
| buffered = BytesIO() | |||
| slide.get_thumbnail(0.5, 0.5).save(buffered, drawing.imaging.ImageFormat.jpeg) | |||
| slide.get_thumbnail(0.5, 0.5).save( | |||
| buffered, drawing.imaging.ImageFormat.jpeg) | |||
| imgs.append(buffered.getvalue()) | |||
| assert len(imgs) == len(txts), "Slides text and image do not match: {} vs. {}".format(len(imgs), len(txts)) | |||
| assert len(imgs) == len(txts), \ | |||
| "Slides text and image do not match: {} vs. {}".format(len(imgs), len(txts)) | |||
| flds = self.Fields() | |||
| flds.text_chunks = [(txts[i], imgs[i]) for i in range(len(txts))] | |||
| @@ -445,7 +451,8 @@ class TextChunker(HuChunker): | |||
| if isinstance(fnm, str): | |||
| with open(fnm, "r") as f: | |||
| txt = f.read() | |||
| else: txt = fnm.decode("utf-8") | |||
| else: | |||
| txt = fnm.decode("utf-8") | |||
| flds.text_chunks = [(c, None) for c in self.naive_text_chunk(txt)] | |||
| flds.table_chunks = [] | |||
| return flds | |||
| @@ -149,7 +149,8 @@ class EsQueryer: | |||
| atks = toDict(atks) | |||
| btkss = [toDict(tks) for tks in btkss] | |||
| tksim = [self.similarity(atks, btks) for btks in btkss] | |||
| return np.array(sims[0]) * vtweight + np.array(tksim) * tkweight, tksim, sims[0] | |||
| return np.array(sims[0]) * vtweight + \ | |||
| np.array(tksim) * tkweight, tksim, sims[0] | |||
| def similarity(self, qtwt, dtwt): | |||
| if isinstance(dtwt, type("")): | |||
| @@ -159,11 +160,11 @@ class EsQueryer: | |||
| s = 1e-9 | |||
| for k, v in qtwt.items(): | |||
| if k in dtwt: | |||
| s += v# * dtwt[k] | |||
| s += v # * dtwt[k] | |||
| q = 1e-9 | |||
| for k, v in qtwt.items(): | |||
| q += v #* v | |||
| q += v # * v | |||
| #d = 1e-9 | |||
| #for k, v in dtwt.items(): | |||
| # for k, v in dtwt.items(): | |||
| # d += v * v | |||
| return s / q #math.sqrt(q) / math.sqrt(d) | |||
| return s / q # math.sqrt(q) / math.sqrt(d) | |||
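| # In short: with the commented-out factors dropped, similarity() is the | |||
| # coverage s / q of query-token weights found in the document, and | |||
| # hybrid_similarity() above blends it with the vector score: | |||
| # | |||
| #     sim = vtweight * sims[0] + tkweight * (s / q) | |||
| # | |||
| # where sims[0] is the vector-space score computed above this hunk. | |||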
| @@ -80,14 +80,18 @@ class Dealer: | |||
| if not req.get("sort"): | |||
| s = s.sort( | |||
| {"create_time": {"order": "desc", "unmapped_type": "date"}}, | |||
| {"create_timestamp_flt": {"order": "desc", "unmapped_type": "float"}} | |||
| {"create_timestamp_flt": { | |||
| "order": "desc", "unmapped_type": "float"}} | |||
| ) | |||
| else: | |||
| s = s.sort( | |||
| {"page_num_int": {"order": "asc", "unmapped_type": "float", "mode": "avg", "numeric_type": "double"}}, | |||
| {"top_int": {"order": "asc", "unmapped_type": "float", "mode": "avg", "numeric_type": "double"}}, | |||
| {"page_num_int": {"order": "asc", "unmapped_type": "float", | |||
| "mode": "avg", "numeric_type": "double"}}, | |||
| {"top_int": {"order": "asc", "unmapped_type": "float", | |||
| "mode": "avg", "numeric_type": "double"}}, | |||
| {"create_time": {"order": "desc", "unmapped_type": "date"}}, | |||
| {"create_timestamp_flt": {"order": "desc", "unmapped_type": "float"}} | |||
| {"create_timestamp_flt": { | |||
| "order": "desc", "unmapped_type": "float"}} | |||
| ) | |||
| if qst: | |||
| @@ -180,11 +184,13 @@ class Dealer: | |||
| m = {n: d.get(n) for n in flds if d.get(n) is not None} | |||
| for n, v in m.items(): | |||
| if isinstance(v, type([])): | |||
| m[n] = "\t".join([str(vv) if not isinstance(vv, list) else "\t".join([str(vvv) for vvv in vv]) for vv in v]) | |||
| m[n] = "\t".join([str(vv) if not isinstance( | |||
| vv, list) else "\t".join([str(vvv) for vvv in vv]) for vv in v]) | |||
| continue | |||
| if not isinstance(v, type("")): | |||
| m[n] = str(m[n]) | |||
| if n.find("tks")>0: m[n] = rmSpace(m[n]) | |||
| if n.find("tks") > 0: | |||
| m[n] = rmSpace(m[n]) | |||
| if m: | |||
| res[d["id"]] = m | |||
| @@ -205,12 +211,16 @@ class Dealer: | |||
| if pieces[i] == "```": | |||
| st = i | |||
| i += 1 | |||
| while i<len(pieces) and pieces[i] != "```": | |||
| while i < len(pieces) and pieces[i] != "```": | |||
| i += 1 | |||
| if i < len(pieces): i += 1 | |||
| pieces_.append("".join(pieces[st: i])+"\n") | |||
| if i < len(pieces): | |||
| i += 1 | |||
| pieces_.append("".join(pieces[st: i]) + "\n") | |||
| else: | |||
| pieces_.extend(re.split(r"([^\|][;。?!!\n]|[a-z][.?;!][ \n])", pieces[i])) | |||
| pieces_.extend(re.split( | |||
| r"([^\|][;。?!!\n]|[a-z][.?;!][ \n])", pieces[i])) | |||
| i += 1 | |||
| pieces = pieces_ | |||
| else: | |||
| @@ -234,7 +244,8 @@ class Dealer: | |||
| assert len(ans_v[0]) == len(chunk_v[0]), "The dimension of query and chunk do not match: {} vs. {}".format( | |||
| len(ans_v[0]), len(chunk_v[0])) | |||
| chunks_tks = [huqie.qie(self.qryr.rmWWW(ck)).split(" ") for ck in chunks] | |||
| chunks_tks = [huqie.qie(self.qryr.rmWWW(ck)).split(" ") | |||
| for ck in chunks] | |||
| cites = {} | |||
| for i, a in enumerate(pieces_): | |||
| sim, tksim, vtsim = self.qryr.hybrid_similarity(ans_v[i], | |||
| @@ -258,9 +269,11 @@ class Dealer: | |||
| continue | |||
| if i not in cites: | |||
| continue | |||
| for c in cites[i]: assert int(c) < len(chunk_v) | |||
| for c in cites[i]: | |||
| if c in seted:continue | |||
| assert int(c) < len(chunk_v) | |||
| for c in cites[i]: | |||
| if c in seted: | |||
| continue | |||
| res += f" ##{c}$$" | |||
| seted.add(c) | |||
| @@ -343,7 +356,11 @@ class Dealer: | |||
| if dnm not in ranks["doc_aggs"]: | |||
| ranks["doc_aggs"][dnm] = {"doc_id": did, "count": 0} | |||
| ranks["doc_aggs"][dnm]["count"] += 1 | |||
| ranks["doc_aggs"] = [{"doc_name": k, "doc_id": v["doc_id"], "count": v["count"]} for k,v in sorted(ranks["doc_aggs"].items(), key=lambda x:x[1]["count"]*-1)] | |||
| ranks["doc_aggs"] = [{"doc_name": k, | |||
| "doc_id": v["doc_id"], | |||
| "count": v["count"]} for k, | |||
| v in sorted(ranks["doc_aggs"].items(), | |||
| key=lambda x:x[1]["count"] * -1)] | |||
| return ranks | |||
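| # For example, the aggregation above turns | |||
| #     {"b.pdf": {"doc_id": "2", "count": 5}, "a.pdf": {"doc_id": "1", "count": 2}} | |||
| # into a list sorted by count, descending: | |||
| #     [{"doc_name": "b.pdf", "doc_id": "2", "count": 5}, | |||
| #      {"doc_name": "a.pdf", "doc_id": "1", "count": 2}] | |||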
| @@ -354,10 +371,17 @@ class Dealer: | |||
| replaces = [] | |||
| for r in re.finditer(r" ([a-z_]+_l?tks)( like | ?= ?)'([^']+)'", sql): | |||
| fld, v = r.group(1), r.group(3) | |||
| match = " MATCH({}, '{}', 'operator=OR;minimum_should_match=30%') ".format(fld, huqie.qieqie(huqie.qie(v))) | |||
| replaces.append(("{}{}'{}'".format(r.group(1), r.group(2), r.group(3)), match)) | |||
| for p, r in replaces: sql = sql.replace(p, r, 1) | |||
| match = " MATCH({}, '{}', 'operator=OR;minimum_should_match=30%') ".format( | |||
| fld, huqie.qieqie(huqie.qie(v))) | |||
| replaces.append( | |||
| ("{}{}'{}'".format(r.group(1), r.group(2), r.group(3)), match)) | |||
| for p, r in replaces: | |||
| sql = sql.replace(p, r, 1) | |||
| chat_logger.info(f"To es: {sql}") | |||
| try: | |||
| @@ -366,4 +390,3 @@ class Dealer: | |||
| except Exception as e: | |||
| chat_logger.error(f"SQL failure: {sql} =>" + str(e)) | |||
| return {"error": str(e)} | |||
| @@ -150,8 +150,10 @@ class Dealer: | |||
| return 6 | |||
| def ner(t): | |||
| if re.match(r"[0-9,.]{2,}$", t): return 2 | |||
| if re.match(r"[a-z]{1,2}$", t): return 0.01 | |||
| if re.match(r"[0-9,.]{2,}$", t): | |||
| return 2 | |||
| if re.match(r"[a-z]{1,2}$", t): | |||
| return 0.01 | |||
| if not self.ne or t not in self.ne: | |||
| return 1 | |||
| m = {"toxic": 2, "func": 1, "corp": 3, "loca": 3, "sch": 3, "stock": 3, | |||
| @@ -14,7 +14,7 @@ | |||
| # limitations under the License. | |||
| # | |||
| import os | |||
| from api.utils import get_base_config,decrypt_database_config | |||
| from api.utils import get_base_config, decrypt_database_config | |||
| from api.utils.file_utils import get_project_base_directory | |||
| from api.utils.log_utils import LoggerFactory, getLogger | |||
| @@ -28,7 +28,11 @@ MINIO = decrypt_database_config(name="minio") | |||
| DOC_MAXIMUM_SIZE = 128 * 1024 * 1024 | |||
| # Logger | |||
| LoggerFactory.set_directory(os.path.join(get_project_base_directory(), "logs", "rag")) | |||
| LoggerFactory.set_directory( | |||
| os.path.join(get_project_base_directory(), "logs", "rag")) | |||
| # {CRITICAL: 50, FATAL:50, ERROR:40, WARNING:30, WARN:30, INFO:20, DEBUG:10, NOTSET:0} | |||
| LoggerFactory.LEVEL = 10 | |||
| @@ -37,4 +41,3 @@ minio_logger = getLogger("minio") | |||
| cron_logger = getLogger("cron_logger") | |||
| chunk_logger = getLogger("chunk_logger") | |||
| database_logger = getLogger("database") | |||
| @@ -47,7 +47,7 @@ def collect(tm): | |||
| def set_dispatching(docid): | |||
| try: | |||
| DocumentService.update_by_id( | |||
| docid, {"progress": random.random()*1 / 100., | |||
| docid, {"progress": random.random() * 1 / 100., | |||
| "progress_msg": "Task dispatched...", | |||
| "process_begin_at": get_format_time() | |||
| }) | |||
| @@ -56,7 +56,10 @@ def set_dispatching(docid): | |||
| def dispatch(): | |||
| tm_fnm = os.path.join(get_project_base_directory(), "rag/res", f"broker.tm") | |||
| tm_fnm = os.path.join( | |||
| get_project_base_directory(), "rag/res", "broker.tm") | |||
| tm = findMaxTm(tm_fnm) | |||
| rows = collect(tm) | |||
| if len(rows) == 0: | |||
| @@ -82,17 +85,22 @@ def dispatch(): | |||
| tsks = [] | |||
| if r["type"] == FileType.PDF.value: | |||
| do_layout = r["parser_config"].get("layout_recognize", True) | |||
| pages = PdfParser.total_page_number(r["name"], MINIO.get(r["kb_id"], r["location"])) | |||
| pages = PdfParser.total_page_number( | |||
| r["name"], MINIO.get(r["kb_id"], r["location"])) | |||
| page_size = r["parser_config"].get("task_page_size", 12) | |||
| if r["parser_id"] == "paper": page_size = r["parser_config"].get("task_page_size", 22) | |||
| if r["parser_id"] == "one": page_size = 1000000000 | |||
| if not do_layout: page_size = 1000000000 | |||
| if r["parser_id"] == "paper": | |||
| page_size = r["parser_config"].get("task_page_size", 22) | |||
| if r["parser_id"] == "one": | |||
| page_size = 1000000000 | |||
| if not do_layout: | |||
| page_size = 1000000000 | |||
| page_ranges = r["parser_config"].get("pages") | |||
| if not page_ranges: page_ranges = [(1, 100000)] | |||
| for s,e in page_ranges: | |||
| if not page_ranges: | |||
| page_ranges = [(1, 100000)] | |||
| for s, e in page_ranges: | |||
| s -= 1 | |||
| s = max(0, s) | |||
| e = min(e-1, pages) | |||
| e = min(e - 1, pages) | |||
| for p in range(s, e, page_size): | |||
| task = new_task() | |||
| task["from_page"] = p | |||
| @@ -100,12 +108,14 @@ def dispatch(): | |||
| tsks.append(task) | |||
| elif r["parser_id"] == "table": | |||
| rn = HuExcelParser.row_number(r["name"], MINIO.get(r["kb_id"], r["location"])) | |||
| for i in range(0, rn, 3000): | |||
| task = new_task() | |||
| task["from_page"] = i | |||
| task["to_page"] = min(i + 3000, rn) | |||
| tsks.append(task) | |||
| rn = HuExcelParser.row_number( | |||
| r["name"], MINIO.get(r["kb_id"], r["location"])) | |||
| for i in range(0, rn, 3000): | |||
| task = new_task() | |||
| task["from_page"] = i | |||
| task["to_page"] = min(i + 3000, rn) | |||
| tsks.append(task) | |||
| else: | |||
| tsks.append(new_task()) | |||
| @@ -120,27 +130,37 @@ def update_progress(): | |||
| for d in docs: | |||
| try: | |||
| tsks = TaskService.query(doc_id=d["id"], order_by=Task.create_time) | |||
| if not tsks:continue | |||
| if not tsks: | |||
| continue | |||
| msg = [] | |||
| prg = 0 | |||
| finished = True | |||
| bad = 0 | |||
| status = TaskStatus.RUNNING.value | |||
| for t in tsks: | |||
| if 0 <= t.progress < 1: finished = False | |||
| if 0 <= t.progress < 1: | |||
| finished = False | |||
| prg += t.progress if t.progress >= 0 else 0 | |||
| msg.append(t.progress_msg) | |||
| if t.progress == -1: bad += 1 | |||
| if t.progress == -1: | |||
| bad += 1 | |||
| prg /= len(tsks) | |||
| if finished and bad: | |||
| prg = -1 | |||
| status = TaskStatus.FAIL.value | |||
| elif finished: status = TaskStatus.DONE.value | |||
| elif finished: | |||
| status = TaskStatus.DONE.value | |||
| msg = "\n".join(msg) | |||
| info = {"process_duation": datetime.timestamp(datetime.now())-d["process_begin_at"].timestamp(), "run": status} | |||
| if prg !=0 : info["progress"] = prg | |||
| if msg: info["progress_msg"] = msg | |||
| info = { | |||
| "process_duation": datetime.timestamp( | |||
| datetime.now()) - | |||
| d["process_begin_at"].timestamp(), | |||
| "run": status} | |||
| if prg != 0: | |||
| info["progress"] = prg | |||
| if msg: | |||
| info["progress_msg"] = msg | |||
| DocumentService.update_by_id(d["id"], info) | |||
| except Exception as e: | |||
| cron_logger.error("fetch task exception:" + str(e)) | |||
| @@ -67,7 +67,7 @@ FACTORY = { | |||
| def set_progress(task_id, from_page=0, to_page=-1, | |||
| prog=None, msg="Processing..."): | |||
| if prog is not None and prog < 0: | |||
| msg = "[ERROR]"+msg | |||
| msg = "[ERROR]" + msg | |||
| cancel = TaskService.do_cancel(task_id) | |||
| if cancel: | |||
| msg += " [Canceled]" | |||
| @@ -188,11 +188,13 @@ def embedding(docs, mdl, parser_config={}, callback=None): | |||
| cnts_ = np.array([]) | |||
| for i in range(0, len(cnts), batch_size): | |||
| vts, c = mdl.encode(cnts[i: i+batch_size]) | |||
| if len(cnts_) == 0: cnts_ = vts | |||
| else: cnts_ = np.concatenate((cnts_, vts), axis=0) | |||
| vts, c = mdl.encode(cnts[i: i + batch_size]) | |||
| if len(cnts_) == 0: | |||
| cnts_ = vts | |||
| else: | |||
| cnts_ = np.concatenate((cnts_, vts), axis=0) | |||
| tk_count += c | |||
| callback(prog=0.7+0.2*(i+1)/len(cnts), msg="") | |||
| callback(prog=0.7 + 0.2 * (i + 1) / len(cnts), msg="") | |||
| cnts = cnts_ | |||
| title_w = float(parser_config.get("filename_embd_weight", 0.1)) | |||
| @@ -234,7 +236,9 @@ def main(comm, mod): | |||
| continue | |||
| # TODO: exception handler | |||
| ## set_progress(r["did"], -1, "ERROR: ") | |||
| callback(msg="Finished slicing files(%d). Start to embedding the content."%len(cks)) | |||
| callback( | |||
| msg="Finished slicing files(%d). Start to embedding the content." % | |||
| len(cks)) | |||
| try: | |||
| tk_count = embedding(cks, embd_mdl, r["parser_config"], callback) | |||
| except Exception as e: | |||
| @@ -249,7 +253,7 @@ def main(comm, mod): | |||
| if es_r: | |||
| callback(-1, "Index failure!") | |||
| ELASTICSEARCH.deleteByQuery( | |||
| Q("match", doc_id=r["doc_id"]), idxnm=search.index_name(r["tenant_id"])) | |||
| Q("match", doc_id=r["doc_id"]), idxnm=search.index_name(r["tenant_id"])) | |||
| cron_logger.error(str(es_r)) | |||
| else: | |||
| if TaskService.do_cancel(r["id"]): | |||