### What problem does this PR solve? ### Type of change - [x] Refactoringtags/v0.3.1
| if not e: | if not e: | ||||
| return get_data_error_result( | return get_data_error_result( | ||||
| retmsg="Can't find this knowledgebase!") | retmsg="Can't find this knowledgebase!") | ||||
| if DocumentService.get_doc_count(kb.tenant_id) >= int(os.environ.get('MAX_FILE_NUM_PER_USER', 8192)): | |||||
| MAX_FILE_NUM_PER_USER = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0)) | |||||
| if MAX_FILE_NUM_PER_USER > 0 and DocumentService.get_doc_count(kb.tenant_id) >= MAX_FILE_NUM_PER_USER: | |||||
| return get_data_error_result( | return get_data_error_result( | ||||
| retmsg="Exceed the maximum file number of a free user!") | retmsg="Exceed the maximum file number of a free user!") | ||||
| def factories(): | def factories(): | ||||
| try: | try: | ||||
| fac = LLMFactoriesService.get_all() | fac = LLMFactoriesService.get_all() | ||||
| return get_json_result(data=[f.to_dict() for f in fac if f.name not in ["QAnything", "FastEmbed"]]) | |||||
| return get_json_result(data=[f.to_dict() for f in fac if f.name not in ["Youdao", "FastEmbed"]]) | |||||
| except Exception as e: | except Exception as e: | ||||
| return server_error_response(e) | return server_error_response(e) | ||||
| llms = [m.to_dict() | llms = [m.to_dict() | ||||
| for m in llms if m.status == StatusEnum.VALID.value] | for m in llms if m.status == StatusEnum.VALID.value] | ||||
| for m in llms: | for m in llms: | ||||
| m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in ["QAnything","FastEmbed"] | |||||
| m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in ["Youdao","FastEmbed"] | |||||
| llm_set = set([m["llm_name"] for m in llms]) | llm_set = set([m["llm_name"] for m in llms]) | ||||
| for o in objs: | for o in objs: |
| null=True, | null=True, | ||||
| default="Chinese", | default="Chinese", | ||||
| help_text="English|Chinese") | help_text="English|Chinese") | ||||
| llm_id = CharField(max_length=32, null=False, help_text="default llm ID") | |||||
| llm_id = CharField(max_length=128, null=False, help_text="default llm ID") | |||||
| llm_setting = JSONField(null=False, default={"temperature": 0.1, "top_p": 0.3, "frequency_penalty": 0.7, | llm_setting = JSONField(null=False, default={"temperature": 0.1, "top_p": 0.3, "frequency_penalty": 0.7, | ||||
| "presence_penalty": 0.4, "max_tokens": 215}) | "presence_penalty": 0.4, "max_tokens": 215}) | ||||
| prompt_type = CharField( | prompt_type = CharField( |
| "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION", | "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION", | ||||
| "status": "1", | "status": "1", | ||||
| },{ | },{ | ||||
| "name": "QAnything", | |||||
| "name": "Youdao", | |||||
| "logo": "", | "logo": "", | ||||
| "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION", | "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION", | ||||
| "status": "1", | "status": "1", | ||||
| "max_tokens": 2147483648, | "max_tokens": 2147483648, | ||||
| "model_type": LLMType.EMBEDDING.value | "model_type": LLMType.EMBEDDING.value | ||||
| }, | }, | ||||
| # ------------------------ QAnything ----------------------- | |||||
| # ------------------------ Youdao ----------------------- | |||||
| { | { | ||||
| "fid": factory_infos[7]["name"], | "fid": factory_infos[7]["name"], | ||||
| "llm_name": "maidalun1020/bce-embedding-base_v1", | "llm_name": "maidalun1020/bce-embedding-base_v1", | ||||
| LLMService.filter_delete([LLM.fid == "Local"]) | LLMService.filter_delete([LLM.fid == "Local"]) | ||||
| LLMService.filter_delete([LLM.fid == "Moonshot", LLM.llm_name == "flag-embedding"]) | LLMService.filter_delete([LLM.fid == "Moonshot", LLM.llm_name == "flag-embedding"]) | ||||
| TenantLLMService.filter_delete([TenantLLM.llm_factory == "Moonshot", TenantLLM.llm_name == "flag-embedding"]) | TenantLLMService.filter_delete([TenantLLM.llm_factory == "Moonshot", TenantLLM.llm_name == "flag-embedding"]) | ||||
| LLMFactoriesService.filter_update([LLMFactoriesService.model.name == "QAnything"], {"name": "Youdao"}) | |||||
| LLMService.filter_update([LLMService.model.fid == "QAnything"], {"fid": "Youdao"}) | |||||
| TenantLLMService.filter_update([TenantLLMService.model.llm_factory == "QAnything"], {"llm_factory": "Youdao"}) | |||||
| """ | """ | ||||
| drop table llm; | drop table llm; | ||||
| drop table llm_factories; | drop table llm_factories; |
| if not model_config: | if not model_config: | ||||
| if llm_type == LLMType.EMBEDDING.value: | if llm_type == LLMType.EMBEDDING.value: | ||||
| llm = LLMService.query(llm_name=llm_name) | llm = LLMService.query(llm_name=llm_name) | ||||
| if llm and llm[0].fid in ["QAnything", "FastEmbed"]: | |||||
| if llm and llm[0].fid in ["Youdao", "FastEmbed"]: | |||||
| model_config = {"llm_factory": llm[0].fid, "api_key":"", "llm_name": llm_name, "api_base": ""} | model_config = {"llm_factory": llm[0].fid, "api_key":"", "llm_name": llm_name, "api_base": ""} | ||||
| if not model_config: | if not model_config: | ||||
| if llm_name == "flag-embedding": | if llm_name == "flag-embedding": |
| from api.db.db_models import Task, Document, Knowledgebase, Tenant | from api.db.db_models import Task, Document, Knowledgebase, Tenant | ||||
| from api.db.services.common_service import CommonService | from api.db.services.common_service import CommonService | ||||
| from api.db.services.document_service import DocumentService | from api.db.services.document_service import DocumentService | ||||
| from api.utils import current_timestamp | |||||
| class TaskService(CommonService): | class TaskService(CommonService): | ||||
| cls.model.id == docs[0]["id"]).execute() | cls.model.id == docs[0]["id"]).execute() | ||||
| return docs | return docs | ||||
| @classmethod | |||||
| @DB.connection_context() | |||||
| def get_ongoing_doc_name(cls): | |||||
| with DB.lock("get_task", -1): | |||||
| docs = cls.model.select(*[Document.kb_id, Document.location]) \ | |||||
| .join(Document, on=(cls.model.doc_id == Document.id)) \ | |||||
| .where( | |||||
| Document.status == StatusEnum.VALID.value, | |||||
| Document.run == TaskStatus.RUNNING.value, | |||||
| ~(Document.type == FileType.VIRTUAL.value), | |||||
| cls.model.progress >= 0, | |||||
| cls.model.progress < 1, | |||||
| cls.model.create_time >= current_timestamp() - 180000 | |||||
| ) | |||||
| docs = list(docs.dicts()) | |||||
| if not docs: return [] | |||||
| return list(set([(d["kb_id"], d["location"]) for d in docs])) | |||||
| @classmethod | @classmethod | ||||
| @DB.connection_context() | @DB.connection_context() | ||||
| def do_cancel(cls, id): | def do_cancel(cls, id): |
| self.updown_cnt_mdl.set_param({"device": "cuda"}) | self.updown_cnt_mdl.set_param({"device": "cuda"}) | ||||
| try: | try: | ||||
| model_dir = os.path.join( | model_dir = os.path.join( | ||||
| get_project_base_directory(), | |||||
| "rag/res/deepdoc") | |||||
| get_project_base_directory(), | |||||
| "rag/res/deepdoc") | |||||
| self.updown_cnt_mdl.load_model(os.path.join( | self.updown_cnt_mdl.load_model(os.path.join( | ||||
| model_dir, "updown_concat_xgb.model")) | model_dir, "updown_concat_xgb.model")) | ||||
| except Exception as e: | except Exception as e: | ||||
| self.updown_cnt_mdl.load_model(os.path.join( | self.updown_cnt_mdl.load_model(os.path.join( | ||||
| model_dir, "updown_concat_xgb.model")) | model_dir, "updown_concat_xgb.model")) | ||||
| self.page_from = 0 | self.page_from = 0 | ||||
| """ | """ | ||||
| If you have trouble downloading HuggingFace models, -_^ this might help!! | If you have trouble downloading HuggingFace models, -_^ this might help!! | ||||
| def _y_dis( | def _y_dis( | ||||
| self, a, b): | self, a, b): | ||||
| return ( | return ( | ||||
| b["top"] + b["bottom"] - a["top"] - a["bottom"]) / 2 | |||||
| b["top"] + b["bottom"] - a["top"] - a["bottom"]) / 2 | |||||
| def _match_proj(self, b): | def _match_proj(self, b): | ||||
| proj_patt = [ | proj_patt = [ | ||||
| tks_down = huqie.qie(down["text"][:LEN]).split(" ") | tks_down = huqie.qie(down["text"][:LEN]).split(" ") | ||||
| tks_up = huqie.qie(up["text"][-LEN:]).split(" ") | tks_up = huqie.qie(up["text"][-LEN:]).split(" ") | ||||
| tks_all = up["text"][-LEN:].strip() \ | tks_all = up["text"][-LEN:].strip() \ | ||||
| + (" " if re.match(r"[a-zA-Z0-9]+", | |||||
| up["text"][-1] + down["text"][0]) else "") \ | |||||
| + down["text"][:LEN].strip() | |||||
| + (" " if re.match(r"[a-zA-Z0-9]+", | |||||
| up["text"][-1] + down["text"][0]) else "") \ | |||||
| + down["text"][:LEN].strip() | |||||
| tks_all = huqie.qie(tks_all).split(" ") | tks_all = huqie.qie(tks_all).split(" ") | ||||
| fea = [ | fea = [ | ||||
| up.get("R", -1) == down.get("R", -1), | up.get("R", -1) == down.get("R", -1), | ||||
| True if re.search(r"[,,][^。.]+$", up["text"]) else False, | True if re.search(r"[,,][^。.]+$", up["text"]) else False, | ||||
| True if re.search(r"[,,][^。.]+$", up["text"]) else False, | True if re.search(r"[,,][^。.]+$", up["text"]) else False, | ||||
| True if re.search(r"[\((][^\))]+$", up["text"]) | True if re.search(r"[\((][^\))]+$", up["text"]) | ||||
| and re.search(r"[\))]", down["text"]) else False, | |||||
| and re.search(r"[\))]", down["text"]) else False, | |||||
| self._match_proj(down), | self._match_proj(down), | ||||
| True if re.match(r"[A-Z]", down["text"]) else False, | True if re.match(r"[A-Z]", down["text"]) else False, | ||||
| True if re.match(r"[A-Z]", up["text"][-1]) else False, | True if re.match(r"[A-Z]", up["text"][-1]) else False, | ||||
| continue | continue | ||||
| for tb in tbls: # for table | for tb in tbls: # for table | ||||
| left, top, right, bott = tb["x0"] - MARGIN, tb["top"] - MARGIN, \ | left, top, right, bott = tb["x0"] - MARGIN, tb["top"] - MARGIN, \ | ||||
| tb["x1"] + MARGIN, tb["bottom"] + MARGIN | |||||
| tb["x1"] + MARGIN, tb["bottom"] + MARGIN | |||||
| left *= ZM | left *= ZM | ||||
| top *= ZM | top *= ZM | ||||
| right *= ZM | right *= ZM | ||||
| for b in bxs: | for b in bxs: | ||||
| if not b["text"]: | if not b["text"]: | ||||
| left, right, top, bott = b["x0"] * ZM, b["x1"] * \ | left, right, top, bott = b["x0"] * ZM, b["x1"] * \ | ||||
| ZM, b["top"] * ZM, b["bottom"] * ZM | |||||
| ZM, b["top"] * ZM, b["bottom"] * ZM | |||||
| b["text"] = self.ocr.recognize(np.array(img), | b["text"] = self.ocr.recognize(np.array(img), | ||||
| np.array([[left, top], [right, top], [right, bott], [left, bott]], | np.array([[left, top], [right, top], [right, bott], [left, bott]], | ||||
| dtype=np.float32)) | dtype=np.float32)) | ||||
| i += 1 | i += 1 | ||||
| continue | continue | ||||
| lout_no = str(self.boxes[i]["page_number"]) + \ | lout_no = str(self.boxes[i]["page_number"]) + \ | ||||
| "-" + str(self.boxes[i]["layoutno"]) | |||||
| "-" + str(self.boxes[i]["layoutno"]) | |||||
| if TableStructureRecognizer.is_caption(self.boxes[i]) or self.boxes[i]["layout_type"] in ["table caption", | if TableStructureRecognizer.is_caption(self.boxes[i]) or self.boxes[i]["layout_type"] in ["table caption", | ||||
| "title", | "title", | ||||
| "figure caption", | "figure caption", | ||||
| self.outlines.append((a["/Title"], depth)) | self.outlines.append((a["/Title"], depth)) | ||||
| continue | continue | ||||
| dfs(a, depth + 1) | dfs(a, depth + 1) | ||||
| dfs(outlines, 0) | dfs(outlines, 0) | ||||
| except Exception as e: | except Exception as e: | ||||
| logging.warning(f"Outlines exception: {e}") | logging.warning(f"Outlines exception: {e}") | ||||
| logging.info("Images converted.") | logging.info("Images converted.") | ||||
| self.is_english = [re.search(r"[a-zA-Z0-9,/¸;:'\[\]\(\)!@#$%^&*\"?<>._-]{30,}", "".join( | self.is_english = [re.search(r"[a-zA-Z0-9,/¸;:'\[\]\(\)!@#$%^&*\"?<>._-]{30,}", "".join( | ||||
| random.choices([c["text"] for c in self.page_chars[i]], k=min(100, len(self.page_chars[i]))))) for i in | random.choices([c["text"] for c in self.page_chars[i]], k=min(100, len(self.page_chars[i]))))) for i in | ||||
| range(len(self.page_chars))] | |||||
| range(len(self.page_chars))] | |||||
| if sum([1 if e else 0 for e in self.is_english]) > len( | if sum([1 if e else 0 for e in self.is_english]) > len( | ||||
| self.page_images) / 2: | self.page_images) / 2: | ||||
| self.is_english = True | self.is_english = True | ||||
| j += 1 | j += 1 | ||||
| self.__ocr(i + 1, img, chars, zoomin) | self.__ocr(i + 1, img, chars, zoomin) | ||||
| #if callback: | |||||
| # callback(prog=(i + 1) * 0.6 / len(self.page_images), msg="") | |||||
| #print("OCR:", timer()-st) | |||||
| if callback and i % 6 == 5: | |||||
| callback(prog=(i + 1) * 0.6 / len(self.page_images), msg="") | |||||
| # print("OCR:", timer()-st) | |||||
| if not self.is_english and not any( | if not self.is_english and not any( | ||||
| [c for c in self.page_chars]) and self.boxes: | [c for c in self.page_chars]) and self.boxes: | ||||
| left, right, top, bottom = float(left), float( | left, right, top, bottom = float(left), float( | ||||
| right), float(top), float(bottom) | right), float(top), float(bottom) | ||||
| poss.append(([int(p) - 1 for p in pn.split("-")], | poss.append(([int(p) - 1 for p in pn.split("-")], | ||||
| left, right, top, bottom)) | |||||
| left, right, top, bottom)) | |||||
| if not poss: | if not poss: | ||||
| if need_position: | if need_position: | ||||
| return None, None | return None, None | ||||
| self.page_images[pns[0]].crop((left * ZM, top * ZM, | self.page_images[pns[0]].crop((left * ZM, top * ZM, | ||||
| right * | right * | ||||
| ZM, min( | ZM, min( | ||||
| bottom, self.page_images[pns[0]].size[1]) | |||||
| bottom, self.page_images[pns[0]].size[1]) | |||||
| )) | )) | ||||
| ) | ) | ||||
| if 0 < ii < len(poss) - 1: | if 0 < ii < len(poss) - 1: |
| "Tongyi-Qianwen": HuEmbedding, #QWenEmbed, | "Tongyi-Qianwen": HuEmbedding, #QWenEmbed, | ||||
| "ZHIPU-AI": ZhipuEmbed, | "ZHIPU-AI": ZhipuEmbed, | ||||
| "FastEmbed": FastEmbed, | "FastEmbed": FastEmbed, | ||||
| "QAnything": QAnythingEmbed | |||||
| "Youdao": YoudaoEmbed | |||||
| } | } | ||||
| return np.array(res.data[0].embedding), res.usage.total_tokens | return np.array(res.data[0].embedding), res.usage.total_tokens | ||||
| class QAnythingEmbed(Base): | |||||
| class YoudaoEmbed(Base): | |||||
| _client = None | _client = None | ||||
| def __init__(self, key=None, model_name="maidalun1020/bce-embedding-base_v1", **kwargs): | def __init__(self, key=None, model_name="maidalun1020/bce-embedding-base_v1", **kwargs): | ||||
| from BCEmbedding import EmbeddingModel as qanthing | from BCEmbedding import EmbeddingModel as qanthing | ||||
| if not QAnythingEmbed._client: | |||||
| if not YoudaoEmbed._client: | |||||
| try: | try: | ||||
| print("LOADING BCE...") | print("LOADING BCE...") | ||||
| QAnythingEmbed._client = qanthing(model_name_or_path=os.path.join( | |||||
| YoudaoEmbed._client = qanthing(model_name_or_path=os.path.join( | |||||
| get_project_base_directory(), | get_project_base_directory(), | ||||
| "rag/res/bce-embedding-base_v1")) | "rag/res/bce-embedding-base_v1")) | ||||
| except Exception as e: | except Exception as e: | ||||
| QAnythingEmbed._client = qanthing( | |||||
| YoudaoEmbed._client = qanthing( | |||||
| model_name_or_path=model_name.replace( | model_name_or_path=model_name.replace( | ||||
| "maidalun1020", "InfiniFlow")) | "maidalun1020", "InfiniFlow")) | ||||
| for t in texts: | for t in texts: | ||||
| token_count += num_tokens_from_string(t) | token_count += num_tokens_from_string(t) | ||||
| for i in range(0, len(texts), batch_size): | for i in range(0, len(texts), batch_size): | ||||
| embds = QAnythingEmbed._client.encode(texts[i:i + batch_size]) | |||||
| embds = YoudaoEmbed._client.encode(texts[i:i + batch_size]) | |||||
| res.extend(embds) | res.extend(embds) | ||||
| return np.array(res), token_count | return np.array(res), token_count | ||||
| def encode_queries(self, text): | def encode_queries(self, text): | ||||
| embds = QAnythingEmbed._client.encode([text]) | |||||
| embds = YoudaoEmbed._client.encode([text]) | |||||
| return np.array(embds[0]), num_tokens_from_string(text) | return np.array(embds[0]), num_tokens_from_string(text) |
| import random | |||||
| import time | |||||
| import traceback | |||||
| from api.db.db_models import close_connection | |||||
| from api.db.services.task_service import TaskService | |||||
| from rag.utils import MINIO | |||||
| from rag.utils.redis_conn import REDIS_CONN | |||||
| def collect(): | |||||
| doc_locations = TaskService.get_ongoing_doc_name() | |||||
| #print(tasks) | |||||
| if len(doc_locations) == 0: | |||||
| time.sleep(1) | |||||
| return | |||||
| return doc_locations | |||||
| def main(): | |||||
| locations = collect() | |||||
| if not locations:return | |||||
| print("TASKS:", len(locations)) | |||||
| for kb_id, loc in locations: | |||||
| try: | |||||
| if REDIS_CONN.is_alive(): | |||||
| try: | |||||
| key = "{}/{}".format(kb_id, loc) | |||||
| if REDIS_CONN.exist(key):continue | |||||
| file_bin = MINIO.get(kb_id, loc) | |||||
| REDIS_CONN.transaction(key, file_bin, 12 * 60) | |||||
| print("CACHE:", loc) | |||||
| except Exception as e: | |||||
| traceback.print_stack(e) | |||||
| except Exception as e: | |||||
| traceback.print_stack(e) | |||||
| if __name__ == "__main__": | |||||
| while True: | |||||
| main() | |||||
| close_connection() | |||||
| time.sleep(1) |
| info = { | info = { | ||||
| "process_duation": datetime.timestamp( | "process_duation": datetime.timestamp( | ||||
| datetime.now()) - | datetime.now()) - | ||||
| d["process_begin_at"].timestamp(), | |||||
| d["process_begin_at"].timestamp(), | |||||
| "run": status} | "run": status} | ||||
| if prg != 0: | if prg != 0: | ||||
| info["progress"] = prg | info["progress"] = prg |
| global MINIO | global MINIO | ||||
| if REDIS_CONN.is_alive(): | if REDIS_CONN.is_alive(): | ||||
| try: | try: | ||||
| for _ in range(30): | |||||
| if REDIS_CONN.exist("{}/{}".format(bucket, name)): | |||||
| time.sleep(1) | |||||
| break | |||||
| time.sleep(1) | |||||
| r = REDIS_CONN.get("{}/{}".format(bucket, name)) | r = REDIS_CONN.get("{}/{}".format(bucket, name)) | ||||
| if r: return r | if r: return r | ||||
| cron_logger.warning("Cache missing: {}".format(name)) | |||||
| except Exception as e: | except Exception as e: | ||||
| cron_logger.warning("Get redis[EXCEPTION]:" + str(e)) | cron_logger.warning("Get redis[EXCEPTION]:" + str(e)) | ||||
| return MINIO.get(bucket, name) | return MINIO.get(bucket, name) |
| except Exception as e: | except Exception as e: | ||||
| minio_logger.error(f"Fail rm {bucket}/{fnm}: " + str(e)) | minio_logger.error(f"Fail rm {bucket}/{fnm}: " + str(e)) | ||||
| def get(self, bucket, fnm): | def get(self, bucket, fnm): | ||||
| for _ in range(1): | for _ in range(1): | ||||
| try: | try: |
| def is_alive(self): | def is_alive(self): | ||||
| return self.REDIS is not None | return self.REDIS is not None | ||||
| def exist(self, k): | |||||
| if not self.REDIS: return | |||||
| try: | |||||
| return self.REDIS.exists(k) | |||||
| except Exception as e: | |||||
| logging.warning("[EXCEPTION]exist" + str(k) + "||" + str(e)) | |||||
| self.__open__() | |||||
| def get(self, k): | def get(self, k): | ||||
| if not self.REDIS: return | if not self.REDIS: return | ||||
| try: | try: | ||||
| self.__open__() | self.__open__() | ||||
| return False | return False | ||||
| def transaction(self, key, value, exp=3600): | |||||
| try: | |||||
| pipeline = self.REDIS.pipeline(transaction=True) | |||||
| pipeline.set(key, value, exp, nx=True) | |||||
| pipeline.execute() | |||||
| return True | |||||
| except Exception as e: | |||||
| logging.warning("[EXCEPTION]set" + str(key) + "||" + str(e)) | |||||
| self.__open__() | |||||
| return False | |||||
| REDIS_CONN = RedisDB() | REDIS_CONN = RedisDB() |