### What problem does this PR solve? ### Type of change - [x] Refactoringtags/v0.3.1
| @@ -58,7 +58,8 @@ def upload(): | |||
| if not e: | |||
| return get_data_error_result( | |||
| retmsg="Can't find this knowledgebase!") | |||
| if DocumentService.get_doc_count(kb.tenant_id) >= int(os.environ.get('MAX_FILE_NUM_PER_USER', 8192)): | |||
| MAX_FILE_NUM_PER_USER = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0)) | |||
| if MAX_FILE_NUM_PER_USER > 0 and DocumentService.get_doc_count(kb.tenant_id) >= MAX_FILE_NUM_PER_USER: | |||
| return get_data_error_result( | |||
| retmsg="Exceed the maximum file number of a free user!") | |||
| @@ -28,7 +28,7 @@ from rag.llm import EmbeddingModel, ChatModel | |||
| def factories(): | |||
| try: | |||
| fac = LLMFactoriesService.get_all() | |||
| return get_json_result(data=[f.to_dict() for f in fac if f.name not in ["QAnything", "FastEmbed"]]) | |||
| return get_json_result(data=[f.to_dict() for f in fac if f.name not in ["Youdao", "FastEmbed"]]) | |||
| except Exception as e: | |||
| return server_error_response(e) | |||
| @@ -174,7 +174,7 @@ def list(): | |||
| llms = [m.to_dict() | |||
| for m in llms if m.status == StatusEnum.VALID.value] | |||
| for m in llms: | |||
| m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in ["QAnything","FastEmbed"] | |||
| m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in ["Youdao","FastEmbed"] | |||
| llm_set = set([m["llm_name"] for m in llms]) | |||
| for o in objs: | |||
| @@ -697,7 +697,7 @@ class Dialog(DataBaseModel): | |||
| null=True, | |||
| default="Chinese", | |||
| help_text="English|Chinese") | |||
| llm_id = CharField(max_length=32, null=False, help_text="default llm ID") | |||
| llm_id = CharField(max_length=128, null=False, help_text="default llm ID") | |||
| llm_setting = JSONField(null=False, default={"temperature": 0.1, "top_p": 0.3, "frequency_penalty": 0.7, | |||
| "presence_penalty": 0.4, "max_tokens": 215}) | |||
| prompt_type = CharField( | |||
| @@ -120,7 +120,7 @@ factory_infos = [{ | |||
| "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION", | |||
| "status": "1", | |||
| },{ | |||
| "name": "QAnything", | |||
| "name": "Youdao", | |||
| "logo": "", | |||
| "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION", | |||
| "status": "1", | |||
| @@ -323,7 +323,7 @@ def init_llm_factory(): | |||
| "max_tokens": 2147483648, | |||
| "model_type": LLMType.EMBEDDING.value | |||
| }, | |||
| # ------------------------ QAnything ----------------------- | |||
| # ------------------------ Youdao ----------------------- | |||
| { | |||
| "fid": factory_infos[7]["name"], | |||
| "llm_name": "maidalun1020/bce-embedding-base_v1", | |||
| @@ -347,7 +347,9 @@ def init_llm_factory(): | |||
| LLMService.filter_delete([LLM.fid == "Local"]) | |||
| LLMService.filter_delete([LLM.fid == "Moonshot", LLM.llm_name == "flag-embedding"]) | |||
| TenantLLMService.filter_delete([TenantLLM.llm_factory == "Moonshot", TenantLLM.llm_name == "flag-embedding"]) | |||
| LLMFactoriesService.filter_update([LLMFactoriesService.model.name == "QAnything"], {"name": "Youdao"}) | |||
| LLMService.filter_update([LLMService.model.fid == "QAnything"], {"fid": "Youdao"}) | |||
| TenantLLMService.filter_update([TenantLLMService.model.llm_factory == "QAnything"], {"llm_factory": "Youdao"}) | |||
| """ | |||
| drop table llm; | |||
| drop table llm_factories; | |||
| @@ -81,7 +81,7 @@ class TenantLLMService(CommonService): | |||
| if not model_config: | |||
| if llm_type == LLMType.EMBEDDING.value: | |||
| llm = LLMService.query(llm_name=llm_name) | |||
| if llm and llm[0].fid in ["QAnything", "FastEmbed"]: | |||
| if llm and llm[0].fid in ["Youdao", "FastEmbed"]: | |||
| model_config = {"llm_factory": llm[0].fid, "api_key":"", "llm_name": llm_name, "api_base": ""} | |||
| if not model_config: | |||
| if llm_name == "flag-embedding": | |||
| @@ -21,6 +21,7 @@ from api.db import StatusEnum, FileType, TaskStatus | |||
| from api.db.db_models import Task, Document, Knowledgebase, Tenant | |||
| from api.db.services.common_service import CommonService | |||
| from api.db.services.document_service import DocumentService | |||
| from api.utils import current_timestamp | |||
| class TaskService(CommonService): | |||
| @@ -70,6 +71,25 @@ class TaskService(CommonService): | |||
| cls.model.id == docs[0]["id"]).execute() | |||
| return docs | |||
| @classmethod | |||
| @DB.connection_context() | |||
| def get_ongoing_doc_name(cls): | |||
| with DB.lock("get_task", -1): | |||
| docs = cls.model.select(*[Document.kb_id, Document.location]) \ | |||
| .join(Document, on=(cls.model.doc_id == Document.id)) \ | |||
| .where( | |||
| Document.status == StatusEnum.VALID.value, | |||
| Document.run == TaskStatus.RUNNING.value, | |||
| ~(Document.type == FileType.VIRTUAL.value), | |||
| cls.model.progress >= 0, | |||
| cls.model.progress < 1, | |||
| cls.model.create_time >= current_timestamp() - 180000 | |||
| ) | |||
| docs = list(docs.dicts()) | |||
| if not docs: return [] | |||
| return list(set([(d["kb_id"], d["location"]) for d in docs])) | |||
| @classmethod | |||
| @DB.connection_context() | |||
| def do_cancel(cls, id): | |||
| @@ -37,8 +37,8 @@ class HuParser: | |||
| self.updown_cnt_mdl.set_param({"device": "cuda"}) | |||
| try: | |||
| model_dir = os.path.join( | |||
| get_project_base_directory(), | |||
| "rag/res/deepdoc") | |||
| get_project_base_directory(), | |||
| "rag/res/deepdoc") | |||
| self.updown_cnt_mdl.load_model(os.path.join( | |||
| model_dir, "updown_concat_xgb.model")) | |||
| except Exception as e: | |||
| @@ -49,7 +49,6 @@ class HuParser: | |||
| self.updown_cnt_mdl.load_model(os.path.join( | |||
| model_dir, "updown_concat_xgb.model")) | |||
| self.page_from = 0 | |||
| """ | |||
| If you have trouble downloading HuggingFace models, -_^ this might help!! | |||
| @@ -76,7 +75,7 @@ class HuParser: | |||
| def _y_dis( | |||
| self, a, b): | |||
| return ( | |||
| b["top"] + b["bottom"] - a["top"] - a["bottom"]) / 2 | |||
| b["top"] + b["bottom"] - a["top"] - a["bottom"]) / 2 | |||
| def _match_proj(self, b): | |||
| proj_patt = [ | |||
| @@ -99,9 +98,9 @@ class HuParser: | |||
| tks_down = huqie.qie(down["text"][:LEN]).split(" ") | |||
| tks_up = huqie.qie(up["text"][-LEN:]).split(" ") | |||
| tks_all = up["text"][-LEN:].strip() \ | |||
| + (" " if re.match(r"[a-zA-Z0-9]+", | |||
| up["text"][-1] + down["text"][0]) else "") \ | |||
| + down["text"][:LEN].strip() | |||
| + (" " if re.match(r"[a-zA-Z0-9]+", | |||
| up["text"][-1] + down["text"][0]) else "") \ | |||
| + down["text"][:LEN].strip() | |||
| tks_all = huqie.qie(tks_all).split(" ") | |||
| fea = [ | |||
| up.get("R", -1) == down.get("R", -1), | |||
| @@ -123,7 +122,7 @@ class HuParser: | |||
| True if re.search(r"[,,][^。.]+$", up["text"]) else False, | |||
| True if re.search(r"[,,][^。.]+$", up["text"]) else False, | |||
| True if re.search(r"[\((][^\))]+$", up["text"]) | |||
| and re.search(r"[\))]", down["text"]) else False, | |||
| and re.search(r"[\))]", down["text"]) else False, | |||
| self._match_proj(down), | |||
| True if re.match(r"[A-Z]", down["text"]) else False, | |||
| True if re.match(r"[A-Z]", up["text"][-1]) else False, | |||
| @@ -185,7 +184,7 @@ class HuParser: | |||
| continue | |||
| for tb in tbls: # for table | |||
| left, top, right, bott = tb["x0"] - MARGIN, tb["top"] - MARGIN, \ | |||
| tb["x1"] + MARGIN, tb["bottom"] + MARGIN | |||
| tb["x1"] + MARGIN, tb["bottom"] + MARGIN | |||
| left *= ZM | |||
| top *= ZM | |||
| right *= ZM | |||
| @@ -297,7 +296,7 @@ class HuParser: | |||
| for b in bxs: | |||
| if not b["text"]: | |||
| left, right, top, bott = b["x0"] * ZM, b["x1"] * \ | |||
| ZM, b["top"] * ZM, b["bottom"] * ZM | |||
| ZM, b["top"] * ZM, b["bottom"] * ZM | |||
| b["text"] = self.ocr.recognize(np.array(img), | |||
| np.array([[left, top], [right, top], [right, bott], [left, bott]], | |||
| dtype=np.float32)) | |||
| @@ -622,7 +621,7 @@ class HuParser: | |||
| i += 1 | |||
| continue | |||
| lout_no = str(self.boxes[i]["page_number"]) + \ | |||
| "-" + str(self.boxes[i]["layoutno"]) | |||
| "-" + str(self.boxes[i]["layoutno"]) | |||
| if TableStructureRecognizer.is_caption(self.boxes[i]) or self.boxes[i]["layout_type"] in ["table caption", | |||
| "title", | |||
| "figure caption", | |||
| @@ -975,6 +974,7 @@ class HuParser: | |||
| self.outlines.append((a["/Title"], depth)) | |||
| continue | |||
| dfs(a, depth + 1) | |||
| dfs(outlines, 0) | |||
| except Exception as e: | |||
| logging.warning(f"Outlines exception: {e}") | |||
| @@ -984,7 +984,7 @@ class HuParser: | |||
| logging.info("Images converted.") | |||
| self.is_english = [re.search(r"[a-zA-Z0-9,/¸;:'\[\]\(\)!@#$%^&*\"?<>._-]{30,}", "".join( | |||
| random.choices([c["text"] for c in self.page_chars[i]], k=min(100, len(self.page_chars[i]))))) for i in | |||
| range(len(self.page_chars))] | |||
| range(len(self.page_chars))] | |||
| if sum([1 if e else 0 for e in self.is_english]) > len( | |||
| self.page_images) / 2: | |||
| self.is_english = True | |||
| @@ -1012,9 +1012,9 @@ class HuParser: | |||
| j += 1 | |||
| self.__ocr(i + 1, img, chars, zoomin) | |||
| #if callback: | |||
| # callback(prog=(i + 1) * 0.6 / len(self.page_images), msg="") | |||
| #print("OCR:", timer()-st) | |||
| if callback and i % 6 == 5: | |||
| callback(prog=(i + 1) * 0.6 / len(self.page_images), msg="") | |||
| # print("OCR:", timer()-st) | |||
| if not self.is_english and not any( | |||
| [c for c in self.page_chars]) and self.boxes: | |||
| @@ -1050,7 +1050,7 @@ class HuParser: | |||
| left, right, top, bottom = float(left), float( | |||
| right), float(top), float(bottom) | |||
| poss.append(([int(p) - 1 for p in pn.split("-")], | |||
| left, right, top, bottom)) | |||
| left, right, top, bottom)) | |||
| if not poss: | |||
| if need_position: | |||
| return None, None | |||
| @@ -1076,7 +1076,7 @@ class HuParser: | |||
| self.page_images[pns[0]].crop((left * ZM, top * ZM, | |||
| right * | |||
| ZM, min( | |||
| bottom, self.page_images[pns[0]].size[1]) | |||
| bottom, self.page_images[pns[0]].size[1]) | |||
| )) | |||
| ) | |||
| if 0 < ii < len(poss) - 1: | |||
| @@ -25,7 +25,7 @@ EmbeddingModel = { | |||
| "Tongyi-Qianwen": HuEmbedding, #QWenEmbed, | |||
| "ZHIPU-AI": ZhipuEmbed, | |||
| "FastEmbed": FastEmbed, | |||
| "QAnything": QAnythingEmbed | |||
| "Youdao": YoudaoEmbed | |||
| } | |||
| @@ -229,19 +229,19 @@ class XinferenceEmbed(Base): | |||
| return np.array(res.data[0].embedding), res.usage.total_tokens | |||
| class QAnythingEmbed(Base): | |||
| class YoudaoEmbed(Base): | |||
| _client = None | |||
| def __init__(self, key=None, model_name="maidalun1020/bce-embedding-base_v1", **kwargs): | |||
| from BCEmbedding import EmbeddingModel as qanthing | |||
| if not QAnythingEmbed._client: | |||
| if not YoudaoEmbed._client: | |||
| try: | |||
| print("LOADING BCE...") | |||
| QAnythingEmbed._client = qanthing(model_name_or_path=os.path.join( | |||
| YoudaoEmbed._client = qanthing(model_name_or_path=os.path.join( | |||
| get_project_base_directory(), | |||
| "rag/res/bce-embedding-base_v1")) | |||
| except Exception as e: | |||
| QAnythingEmbed._client = qanthing( | |||
| YoudaoEmbed._client = qanthing( | |||
| model_name_or_path=model_name.replace( | |||
| "maidalun1020", "InfiniFlow")) | |||
| @@ -251,10 +251,10 @@ class QAnythingEmbed(Base): | |||
| for t in texts: | |||
| token_count += num_tokens_from_string(t) | |||
| for i in range(0, len(texts), batch_size): | |||
| embds = QAnythingEmbed._client.encode(texts[i:i + batch_size]) | |||
| embds = YoudaoEmbed._client.encode(texts[i:i + batch_size]) | |||
| res.extend(embds) | |||
| return np.array(res), token_count | |||
| def encode_queries(self, text): | |||
| embds = QAnythingEmbed._client.encode([text]) | |||
| embds = YoudaoEmbed._client.encode([text]) | |||
| return np.array(embds[0]), num_tokens_from_string(text) | |||
| @@ -0,0 +1,43 @@ | |||
| import random | |||
| import time | |||
| import traceback | |||
| from api.db.db_models import close_connection | |||
| from api.db.services.task_service import TaskService | |||
| from rag.utils import MINIO | |||
| from rag.utils.redis_conn import REDIS_CONN | |||
| def collect(): | |||
| doc_locations = TaskService.get_ongoing_doc_name() | |||
| #print(tasks) | |||
| if len(doc_locations) == 0: | |||
| time.sleep(1) | |||
| return | |||
| return doc_locations | |||
| def main(): | |||
| locations = collect() | |||
| if not locations:return | |||
| print("TASKS:", len(locations)) | |||
| for kb_id, loc in locations: | |||
| try: | |||
| if REDIS_CONN.is_alive(): | |||
| try: | |||
| key = "{}/{}".format(kb_id, loc) | |||
| if REDIS_CONN.exist(key):continue | |||
| file_bin = MINIO.get(kb_id, loc) | |||
| REDIS_CONN.transaction(key, file_bin, 12 * 60) | |||
| print("CACHE:", loc) | |||
| except Exception as e: | |||
| traceback.print_stack(e) | |||
| except Exception as e: | |||
| traceback.print_stack(e) | |||
| if __name__ == "__main__": | |||
| while True: | |||
| main() | |||
| close_connection() | |||
| time.sleep(1) | |||
| @@ -167,7 +167,7 @@ def update_progress(): | |||
| info = { | |||
| "process_duation": datetime.timestamp( | |||
| datetime.now()) - | |||
| d["process_begin_at"].timestamp(), | |||
| d["process_begin_at"].timestamp(), | |||
| "run": status} | |||
| if prg != 0: | |||
| info["progress"] = prg | |||
| @@ -107,8 +107,14 @@ def get_minio_binary(bucket, name): | |||
| global MINIO | |||
| if REDIS_CONN.is_alive(): | |||
| try: | |||
| for _ in range(30): | |||
| if REDIS_CONN.exist("{}/{}".format(bucket, name)): | |||
| time.sleep(1) | |||
| break | |||
| time.sleep(1) | |||
| r = REDIS_CONN.get("{}/{}".format(bucket, name)) | |||
| if r: return r | |||
| cron_logger.warning("Cache missing: {}".format(name)) | |||
| except Exception as e: | |||
| cron_logger.warning("Get redis[EXCEPTION]:" + str(e)) | |||
| return MINIO.get(bucket, name) | |||
| @@ -56,7 +56,6 @@ class HuMinio(object): | |||
| except Exception as e: | |||
| minio_logger.error(f"Fail rm {bucket}/{fnm}: " + str(e)) | |||
| def get(self, bucket, fnm): | |||
| for _ in range(1): | |||
| try: | |||
| @@ -25,6 +25,14 @@ class RedisDB: | |||
| def is_alive(self): | |||
| return self.REDIS is not None | |||
| def exist(self, k): | |||
| if not self.REDIS: return | |||
| try: | |||
| return self.REDIS.exists(k) | |||
| except Exception as e: | |||
| logging.warning("[EXCEPTION]exist" + str(k) + "||" + str(e)) | |||
| self.__open__() | |||
| def get(self, k): | |||
| if not self.REDIS: return | |||
| try: | |||
| @@ -51,5 +59,16 @@ class RedisDB: | |||
| self.__open__() | |||
| return False | |||
| def transaction(self, key, value, exp=3600): | |||
| try: | |||
| pipeline = self.REDIS.pipeline(transaction=True) | |||
| pipeline.set(key, value, exp, nx=True) | |||
| pipeline.execute() | |||
| return True | |||
| except Exception as e: | |||
| logging.warning("[EXCEPTION]set" + str(key) + "||" + str(e)) | |||
| self.__open__() | |||
| return False | |||
| REDIS_CONN = RedisDB() | |||