@@ -28,8 +28,6 @@ from api.utils import CustomJSONEncoder
from flask_session import Session
from flask_login import LoginManager
from api.settings import RetCode, SECRET_KEY, stat_logger
from api.hook import HookManager
from api.hook.common.parameters import AuthenticationParameters, ClientAuthenticationParameters
from api.settings import API_VERSION, CLIENT_AUTHENTICATION, SITE_AUTHENTICATION, access_logger
from api.utils.api_utils import get_json_result, server_error_response
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
@@ -96,37 +94,7 @@ client_urls_prefix = [
]
def client_authentication_before_request():
result = HookManager.client_authentication(ClientAuthenticationParameters(
request.full_path, request.headers,
request.form, request.data, request.json,
))
if result.code != RetCode.SUCCESS:
return get_json_result(result.code, result.message)
def site_authentication_before_request():
for url_prefix in client_urls_prefix:
if request.path.startswith(url_prefix):
return
result = HookManager.site_authentication(AuthenticationParameters(
request.headers.get('site_signature'),
request.json,
))
if result.code != RetCode.SUCCESS:
return get_json_result(result.code, result.message)
@app.before_request
def authentication_before_request():
if CLIENT_AUTHENTICATION:
return client_authentication_before_request()
if SITE_AUTHENTICATION:
return site_authentication_before_request()
@login_manager.request_loader
def load_user(web_request):
@@ -57,7 +57,7 @@ def list():
for id in sres.ids:
d = {
"chunk_id": id,
"content_ltks": rmSpace(sres.highlight[id]) if question else sres.field[id]["content_ltks"],
"content_with_weight": rmSpace(sres.highlight[id]) if question else sres.field[id]["content_with_weight"],
"doc_id": sres.field[id]["doc_id"],
"docnm_kwd": sres.field[id]["docnm_kwd"],
"important_kwd": sres.field[id].get("important_kwd", []),
@@ -134,7 +134,7 @@ def set():
q, a = rmPrefix(arr[0]), rmPrefix(arr[1])
d = beAdoc(d, arr[0], arr[1], not any([huqie.is_chinese(t) for t in q+a]))
v, c = embd_mdl.encode([doc.name, req["content_ltks"]])
v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
d["q_%d_vec" % len(v)] = v.tolist()
ELASTICSEARCH.upsert([d], search.index_name(tenant_id))
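Note on the hunk above: the chunk vector blends the document-name embedding with the content embedding at 0.1/0.9, except for Q&A documents, which keep the content vector alone. A minimal sketch of that blend (names are illustrative; it assumes encode() returns one vector per input text):
import numpy as np
def blend_vectors(title_vec, content_vec, title_weight=0.1):
    # weighted sum of title and content embeddings, mirroring v = 0.1 * v[0] + 0.9 * v[1]
    return title_weight * np.asarray(title_vec) + (1.0 - title_weight) * np.asarray(content_vec)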
@@ -175,13 +175,13 @@ def rm():
@manager.route('/create', methods=['POST'])
@login_required
@validate_request("doc_id", "content_ltks")
@validate_request("doc_id", "content_with_weight")
def create():
req = request.json
md5 = hashlib.md5()
md5.update((req["content_ltks"] + req["doc_id"]).encode("utf-8"))
md5.update((req["content_with_weight"] + req["doc_id"]).encode("utf-8"))
chunck_id = md5.hexdigest()
d = {"id": chunck_id, "content_ltks": huqie.qie(req["content_ltks"])}
d = {"id": chunck_id, "content_ltks": huqie.qie(req["content_with_weight"])}
d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"])
d["important_kwd"] = req.get("important_kwd", [])
d["important_tks"] = huqie.qie(" ".join(req.get("important_kwd", [])))
@@ -201,7 +201,7 @@ def create():
embd_mdl = TenantLLMService.model_instance(
tenant_id, LLMType.EMBEDDING.value)
v, c = embd_mdl.encode([doc.name, req["content_ltks"]])
v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
DocumentService.increment_chunk_num(req["doc_id"], doc.kb_id, c, 1, 0)
v = 0.1 * v[0] + 0.9 * v[1]
d["q_%d_vec" % len(v)] = v.tolist()
@@ -175,7 +175,7 @@ def chat(dialog, messages, **kwargs):
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
kbinfos = retrievaler.retrieval(question, embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n, dialog.similarity_threshold,
dialog.vector_similarity_weight, top=1024, aggs=False)
knowledges = [ck["content_ltks"] for ck in kbinfos["chunks"]]
knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
if not knowledges and prompt_config["empty_response"]:
return {"answer": prompt_config["empty_response"], "retrieval": kbinfos}
@@ -73,6 +73,7 @@ def upload():
"id": get_uuid(),
"kb_id": kb.id,
"parser_id": kb.parser_id,
"parser_config": kb.parser_config,
"created_by": current_user.id,
"type": filename_type(filename),
"name": filename,
@@ -108,6 +109,7 @@ def create():
"id": get_uuid(),
"kb_id": kb.id,
"parser_id": kb.parser_id,
"parser_config": kb.parser_config,
"created_by": current_user.id,
"type": FileType.VIRTUAL,
"name": req["name"],
@@ -128,8 +130,8 @@ def list():
data=False, retmsg='Lack of "KB ID"', retcode=RetCode.ARGUMENT_ERROR)
keywords = request.args.get("keywords", "")
page_number = request.args.get("page", 1)
items_per_page = request.args.get("page_size", 15)
page_number = int(request.args.get("page", 1))
items_per_page = int(request.args.get("page_size", 15))
orderby = request.args.get("orderby", "create_time")
desc = request.args.get("desc", True)
try:
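Note: request.args values arrive as strings, which is why the hunk above wraps "page" and "page_size" in int(). A defensive variant (hypothetical helper, not part of this diff) that also tolerates junk input:
def int_arg(args, name, default):
    # fall back to the default when the parameter is missing or not a number
    try:
        return int(args.get(name, default))
    except (TypeError, ValueError):
        return default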
@@ -214,7 +216,9 @@ def run():
req = request.json
try:
for id in req["doc_ids"]:
DocumentService.update_by_id(id, {"run": str(req["run"]), "progress": 0})
info = {"run": str(req["run"]), "progress": 0}
if str(req["run"]) == TaskStatus.RUNNING.value: info["progress_msg"] = ""
DocumentService.update_by_id(id, info)
if str(req["run"]) == TaskStatus.CANCEL.value:
tenant_id = DocumentService.get_tenant_id(id)
if not tenant_id:
@@ -29,7 +29,7 @@ from api.utils.api_utils import get_json_result
@manager.route('/create', methods=['post'])
@login_required
@validate_request("name", "description", "permission", "parser_id")
@validate_request("name")
def create():
req = request.json
req["name"] = req["name"].strip()
@@ -77,3 +77,4 @@ class ParserType(StrEnum):
RESUME = "resume"
BOOK = "book"
QA = "qa"
TABLE = "table"
@@ -29,7 +29,7 @@ from peewee import (
)
from playhouse.pool import PooledMySQLDatabase
from api.db import SerializedType
from api.db import SerializedType, ParserType
from api.settings import DATABASE, stat_logger, SECRET_KEY
from api.utils.log_utils import getLogger
from api import utils
@@ -381,7 +381,8 @@ class Tenant(DataBaseModel):
embd_id = CharField(max_length=128, null=False, help_text="default embedding model ID")
asr_id = CharField(max_length=128, null=False, help_text="default ASR model ID")
img2txt_id = CharField(max_length=128, null=False, help_text="default image to text model ID")
parser_ids = CharField(max_length=128, null=False, help_text="default image to text model ID")
parser_ids = CharField(max_length=128, null=False, help_text="document processors")
credit = IntegerField(default=512)
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
class Meta:
@@ -472,7 +473,8 @@ class Knowledgebase(DataBaseModel):
similarity_threshold = FloatField(default=0.2)
vector_similarity_weight = FloatField(default=0.3)
parser_id = CharField(max_length=32, null=False, help_text="default parser ID")
parser_id = CharField(max_length=32, null=False, help_text="default parser ID", default=ParserType.GENERAL.value)
parser_config = JSONField(null=False, default={"from_page":0, "to_page": 100000})
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
def __str__(self):
@@ -487,6 +489,7 @@ class Document(DataBaseModel):
thumbnail = TextField(null=True, help_text="thumbnail base64 string")
kb_id = CharField(max_length=256, null=False, index=True)
parser_id = CharField(max_length=32, null=False, help_text="default parser ID")
parser_config = JSONField(null=False, default={"from_page":0, "to_page": 100000})
source_type = CharField(max_length=128, null=False, default="local", help_text="where does this document come from")
| type = CharField(max_length=32, null=False, help_text="file extension") | |||
| created_by = CharField(max_length=32, null=False, help_text="who created it") | |||
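Note: peewee evaluates a non-callable default once at class-definition time, so the dict literal passed to JSONField above is one shared object; peewee's documentation recommends a callable for mutable defaults. A hedged alternative:
parser_config = JSONField(null=False, default=lambda: {"from_page": 0, "to_page": 100000})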
@@ -1,157 +0,0 @@
#
# Copyright 2021 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import abc
import json
import time
from functools import wraps
from shortuuid import ShortUUID
from api.versions import get_rag_version
from api.errors.error_services import *
from api.settings import (
GRPC_PORT, HOST, HTTP_PORT,
RANDOM_INSTANCE_ID, stat_logger,
)
instance_id = ShortUUID().random(length=8) if RANDOM_INSTANCE_ID else f'flow-{HOST}-{HTTP_PORT}'
server_instance = (
f'{HOST}:{GRPC_PORT}',
json.dumps({
'instance_id': instance_id,
'timestamp': round(time.time() * 1000),
'version': get_rag_version() or '',
'host': HOST,
'grpc_port': GRPC_PORT,
'http_port': HTTP_PORT,
}),
)
def check_service_supported(method):
"""Decorator to check if `service_name` is supported.
The attribute `supported_services` MUST be defined in class.
The first and second arguments of `method` MUST be `self` and `service_name`.
:param Callable method: The class method.
:return: The inner wrapper function.
:rtype: Callable
"""
@wraps(method)
def magic(self, service_name, *args, **kwargs):
if service_name not in self.supported_services:
raise ServiceNotSupported(service_name=service_name)
return method(self, service_name, *args, **kwargs)
return magic
class ServicesDB(abc.ABC):
"""Database for storage service urls.
Abstract base class for the real backends.
"""
@property
@abc.abstractmethod
def supported_services(self):
"""The names of supported services.
The returned list SHOULD contain `ragflow` (model download) and `servings` (RAG-Serving).
:return: The service names.
:rtype: list
"""
pass
@abc.abstractmethod
def _get_serving(self):
pass
def get_serving(self):
try:
return self._get_serving()
except ServicesError as e:
stat_logger.exception(e)
return []
@abc.abstractmethod
def _insert(self, service_name, service_url, value=''):
pass
@check_service_supported
def insert(self, service_name, service_url, value=''):
"""Insert a service url to database.
:param str service_name: The service name.
:param str service_url: The service url.
:return: None
"""
try:
self._insert(service_name, service_url, value)
except ServicesError as e:
stat_logger.exception(e)
@abc.abstractmethod
def _delete(self, service_name, service_url):
pass
@check_service_supported
def delete(self, service_name, service_url):
"""Delete a service url from database.
:param str service_name: The service name.
:param str service_url: The service url.
:return: None
"""
try:
self._delete(service_name, service_url)
except ServicesError as e:
stat_logger.exception(e)
def register_flow(self):
| """Call `self.insert` for insert the flow server address to databae. | |||
:return: None
"""
self.insert('flow-server', *server_instance)
def unregister_flow(self):
| """Call `self.delete` for delete the flow server address from databae. | |||
:return: None
"""
self.delete('flow-server', server_instance[0])
@abc.abstractmethod
def _get_urls(self, service_name, with_values=False):
pass
@check_service_supported
def get_urls(self, service_name, with_values=False):
"""Query service urls from database. The urls may belong to other nodes.
Currently, only `ragflow` (model download) urls and `servings` (RAG-Serving) urls are supported.
`ragflow` is a url containing scheme, host, port and path,
while `servings` only contains host and port.
:param str service_name: The service name.
:return: The service urls.
:rtype: list
"""
try:
return self._get_urls(service_name, with_values)
except ServicesError as e:
stat_logger.exception(e)
return []
@@ -63,7 +63,7 @@ class DocumentService(CommonService):
@classmethod
@DB.connection_context()
def get_newly_uploaded(cls, tm, mod=0, comm=1, items_per_page=64):
fields = [cls.model.id, cls.model.kb_id, cls.model.parser_id, cls.model.name, cls.model.type, cls.model.location, cls.model.size, Knowledgebase.tenant_id, Tenant.embd_id, Tenant.img2txt_id, Tenant.asr_id, cls.model.update_time]
fields = [cls.model.id, cls.model.kb_id, cls.model.parser_id, cls.model.parser_config, cls.model.name, cls.model.type, cls.model.location, cls.model.size, Knowledgebase.tenant_id, Tenant.embd_id, Tenant.img2txt_id, Tenant.asr_id, cls.model.update_time]
docs = cls.model.select(*fields) \
.join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) \
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))\
@@ -52,7 +52,8 @@ class KnowledgebaseService(CommonService):
cls.model.doc_num,
cls.model.token_num,
cls.model.chunk_num,
cls.model.parser_id]
cls.model.parser_id,
cls.model.parser_config]
kbs = cls.model.select(*fields).join(Tenant, on=((Tenant.id == cls.model.tenant_id)&(Tenant.status== StatusEnum.VALID.value))).where(
(cls.model.id == kb_id),
(cls.model.status == StatusEnum.VALID.value)
@@ -27,7 +27,7 @@ class TaskService(CommonService):
@classmethod
@DB.connection_context()
def get_tasks(cls, tm, mod=0, comm=1, items_per_page=64):
fields = [cls.model.id, cls.model.doc_id, cls.model.from_page,cls.model.to_page, Document.kb_id, Document.parser_id, Document.name, Document.type, Document.location, Document.size, Knowledgebase.tenant_id, Tenant.embd_id, Tenant.img2txt_id, Tenant.asr_id, cls.model.update_time]
fields = [cls.model.id, cls.model.doc_id, cls.model.from_page,cls.model.to_page, Document.kb_id, Document.parser_id, Document.parser_config, Document.name, Document.type, Document.location, Document.size, Knowledgebase.tenant_id, Tenant.embd_id, Tenant.img2txt_id, Tenant.asr_id, cls.model.update_time]
docs = cls.model.select(*fields) \
.join(Document, on=(cls.model.doc_id == Document.id)) \
.join(Knowledgebase, on=(Document.kb_id == Knowledgebase.id)) \
@@ -53,3 +53,13 @@ class TaskService(CommonService):
except Exception as e:
pass
return True
@classmethod
@DB.connection_context()
def update_progress(cls, id, info):
cls.model.update(progress_msg=cls.model.progress_msg + "\n"+info["progress_msg"]).where(
cls.model.id == id).execute()
if "progress" in info:
cls.model.update(progress=info["progress"]).where(
cls.model.id == id).execute()
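Note: update_progress above issues two UPDATE statements, so a concurrent writer can land between them; also, on MySQL the + operator is numeric, so string append is usually spelled CONCAT. A single-statement sketch, assuming MySQL and peewee's fn:
from peewee import fn
data = {cls.model.progress_msg: fn.CONCAT(cls.model.progress_msg, "\n" + info["progress_msg"])}
if "progress" in info:
    data[cls.model.progress] = info["progress"]
cls.model.update(data).where(cls.model.id == id).execute()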
@@ -92,6 +92,12 @@ class TenantService(CommonService):
.join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id==user_id) & (UserTenant.status == StatusEnum.VALID.value) & (UserTenant.role==UserTenantRole.NORMAL.value)))\
.where(cls.model.status == StatusEnum.VALID.value).dicts())
@classmethod
@DB.connection_context()
def decrease(cls, user_id, num):
num = cls.model.update(credit=cls.model.credit - num).where(
cls.model.id == user_id).execute()
if num == 0: raise LookupError("Tenant not found which is supposed to be there")
class UserTenantService(CommonService):
model = UserTenant
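Note: decrease performs the decrement inside a single UPDATE, so it is atomic; if credit must never go negative, the floor can move into the WHERE clause. A hedged variant of the method body above:
rows = cls.model.update(credit=cls.model.credit - num).where(
    (cls.model.id == user_id) & (cls.model.credit >= num)).execute()
if rows == 0:
    raise LookupError("Tenant not found or insufficient credit")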
@@ -1,10 +0,0 @@
from .general_error import *
class RagFlowError(Exception):
message = 'Unknown Rag Flow Error'
def __init__(self, message=None, *args, **kwargs):
message = str(message) if message is not None else self.message
message = message.format(*args, **kwargs)
super().__init__(message)
@@ -1,13 +0,0 @@
from api.errors import RagFlowError
__all__ = ['ServicesError', 'ServiceNotSupported', 'ZooKeeperNotConfigured',
'MissingZooKeeperUsernameOrPassword', 'ZooKeeperBackendError']
class ServicesError(RagFlowError):
message = 'Unknown services error'
class ServiceNotSupported(ServicesError):
message = 'The service {service_name} is not supported'
@@ -1,21 +0,0 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
class ParameterError(Exception):
pass
class PassError(Exception):
pass
@@ -1,57 +0,0 @@
import importlib
from api.hook.common.parameters import SignatureParameters, AuthenticationParameters, \
SignatureReturn, AuthenticationReturn, PermissionReturn, ClientAuthenticationReturn, ClientAuthenticationParameters
from api.settings import HOOK_MODULE, stat_logger,RetCode
class HookManager:
SITE_SIGNATURE = []
SITE_AUTHENTICATION = []
CLIENT_AUTHENTICATION = []
PERMISSION_CHECK = []
@staticmethod
def init():
if HOOK_MODULE is not None:
for modules in HOOK_MODULE.values():
for module in modules.split(";"):
try:
importlib.import_module(module)
except Exception as e:
stat_logger.exception(e)
@staticmethod
def register_site_signature_hook(func):
HookManager.SITE_SIGNATURE.append(func)
@staticmethod
def register_site_authentication_hook(func):
HookManager.SITE_AUTHENTICATION.append(func)
@staticmethod
def register_client_authentication_hook(func):
HookManager.CLIENT_AUTHENTICATION.append(func)
@staticmethod
def register_permission_check_hook(func):
HookManager.PERMISSION_CHECK.append(func)
@staticmethod
def client_authentication(parm: ClientAuthenticationParameters) -> ClientAuthenticationReturn:
if HookManager.CLIENT_AUTHENTICATION:
return HookManager.CLIENT_AUTHENTICATION[0](parm)
return ClientAuthenticationReturn()
@staticmethod
def site_signature(parm: SignatureParameters) -> SignatureReturn:
if HookManager.SITE_SIGNATURE:
return HookManager.SITE_SIGNATURE[0](parm)
return SignatureReturn()
@staticmethod
def site_authentication(parm: AuthenticationParameters) -> AuthenticationReturn:
if HookManager.SITE_AUTHENTICATION:
return HookManager.SITE_AUTHENTICATION[0](parm)
return AuthenticationReturn()
@@ -1,29 +0,0 @@
import requests
from api.db.service_registry import ServiceRegistry
from api.settings import RegistryServiceName
from api.hook import HookManager
from api.hook.common.parameters import ClientAuthenticationParameters, ClientAuthenticationReturn
from api.settings import HOOK_SERVER_NAME
@HookManager.register_client_authentication_hook
def authentication(parm: ClientAuthenticationParameters) -> ClientAuthenticationReturn:
service_list = ServiceRegistry.load_service(
server_name=HOOK_SERVER_NAME,
service_name=RegistryServiceName.CLIENT_AUTHENTICATION.value
)
if not service_list:
| raise Exception(f"client authentication error: no found server" | |||
| f" {HOOK_SERVER_NAME} service client_authentication") | |||
service = service_list[0]
response = getattr(requests, service.f_method.lower(), None)(
url=service.f_url,
json=parm.to_dict()
)
if response.status_code != 200:
raise Exception(
f"client authentication error: request authentication url failed, status code {response.status_code}")
elif response.json().get("code") != 0:
return ClientAuthenticationReturn(code=response.json().get("code"), message=response.json().get("msg"))
return ClientAuthenticationReturn()
@@ -1,25 +0,0 @@
import requests
from api.db.service_registry import ServiceRegistry
from api.settings import RegistryServiceName
from api.hook import HookManager
from api.hook.common.parameters import PermissionCheckParameters, PermissionReturn
from api.settings import HOOK_SERVER_NAME
@HookManager.register_permission_check_hook
def permission(parm: PermissionCheckParameters) -> PermissionReturn:
service_list = ServiceRegistry.load_service(server_name=HOOK_SERVER_NAME, service_name=RegistryServiceName.PERMISSION_CHECK.value)
if not service_list:
| raise Exception(f"permission check error: no found server {HOOK_SERVER_NAME} service permission") | |||
service = service_list[0]
response = getattr(requests, service.f_method.lower(), None)(
url=service.f_url,
json=parm.to_dict()
)
if response.status_code != 200:
raise Exception(
f"permission check error: request permission url failed, status code {response.status_code}")
elif response.json().get("code") != 0:
return PermissionReturn(code=response.json().get("code"), message=response.json().get("msg"))
return PermissionReturn()
@@ -1,49 +0,0 @@
import requests
from api.db.service_registry import ServiceRegistry
from api.settings import RegistryServiceName
from api.hook import HookManager
from api.hook.common.parameters import SignatureParameters, AuthenticationParameters, AuthenticationReturn,\
SignatureReturn
from api.settings import HOOK_SERVER_NAME, PARTY_ID
@HookManager.register_site_signature_hook
def signature(parm: SignatureParameters) -> SignatureReturn:
service_list = ServiceRegistry.load_service(server_name=HOOK_SERVER_NAME, service_name=RegistryServiceName.SIGNATURE.value)
if not service_list:
| raise Exception(f"signature error: no found server {HOOK_SERVER_NAME} service signature") | |||
service = service_list[0]
response = getattr(requests, service.f_method.lower(), None)(
url=service.f_url,
json=parm.to_dict()
)
if response.status_code == 200:
if response.json().get("code") == 0:
return SignatureReturn(site_signature=response.json().get("data"))
else:
raise Exception(f"signature error: request signature url failed, result: {response.json()}")
else:
raise Exception(f"signature error: request signature url failed, status code {response.status_code}")
@HookManager.register_site_authentication_hook
def authentication(parm: AuthenticationParameters) -> AuthenticationReturn:
if not parm.src_party_id or str(parm.src_party_id) == "0":
parm.src_party_id = PARTY_ID
service_list = ServiceRegistry.load_service(server_name=HOOK_SERVER_NAME,
service_name=RegistryServiceName.SITE_AUTHENTICATION.value)
if not service_list:
raise Exception(
| f"site authentication error: no found server {HOOK_SERVER_NAME} service site_authentication") | |||
service = service_list[0]
response = getattr(requests, service.f_method.lower(), None)(
url=service.f_url,
json=parm.to_dict()
)
if response.status_code != 200:
raise Exception(
f"site authentication error: request site_authentication url failed, status code {response.status_code}")
elif response.json().get("code") != 0:
return AuthenticationReturn(code=response.json().get("code"), message=response.json().get("msg"))
return AuthenticationReturn()
@@ -1,56 +0,0 @@
from api.settings import RetCode
class ParametersBase:
def to_dict(self):
d = {}
for k, v in self.__dict__.items():
d[k] = v
return d
class ClientAuthenticationParameters(ParametersBase):
def __init__(self, full_path, headers, form, data, json):
self.full_path = full_path
self.headers = headers
self.form = form
self.data = data
self.json = json
class ClientAuthenticationReturn(ParametersBase):
def __init__(self, code=RetCode.SUCCESS, message="success"):
self.code = code
self.message = message
class SignatureParameters(ParametersBase):
def __init__(self, party_id, body):
self.party_id = party_id
self.body = body
class SignatureReturn(ParametersBase):
def __init__(self, code=RetCode.SUCCESS, site_signature=None):
self.code = code
self.site_signature = site_signature
class AuthenticationParameters(ParametersBase):
def __init__(self, site_signature, body):
self.site_signature = site_signature
self.body = body
class AuthenticationReturn(ParametersBase):
def __init__(self, code=RetCode.SUCCESS, message="success"):
self.code = code
self.message = message
class PermissionReturn(ParametersBase):
def __init__(self, code=RetCode.SUCCESS, message="success"):
self.code = code
self.message = message
@@ -20,12 +20,9 @@ import os
import signal
import sys
import traceback
from werkzeug.serving import run_simple
from api.apps import app
from api.db.runtime_config import RuntimeConfig
from api.hook import HookManager
from api.settings import (
HOST, HTTP_PORT, access_logger, database_logger, stat_logger,
)
@@ -60,8 +57,6 @@ if __name__ == '__main__':
RuntimeConfig.init_env()
RuntimeConfig.init_config(JOB_SERVER_HOST=HOST, HTTP_PORT=HTTP_PORT)
HookManager.init()
peewee_logger = logging.getLogger('peewee')
peewee_logger.propagate = False
# rag_arch.common.log.ROpenHandler
@@ -47,7 +47,7 @@ LLM = get_base_config("llm", {})
CHAT_MDL = LLM.get("chat_model", "gpt-3.5-turbo")
EMBEDDING_MDL = LLM.get("embedding_model", "text-embedding-ada-002")
ASR_MDL = LLM.get("asr_model", "whisper-1")
PARSERS = LLM.get("parsers", "General,Resume,Laws,Product Instructions,Books,Paper,Q&A,Programming Code,Power Point,Research Report")
| PARSERS = LLM.get("parsers", "general:General,resume:esume,laws:Laws,manual:Manual,book:Book,paper:Paper,qa:Q&A,presentation:Presentation") | |||
IMAGE2TEXT_MDL = LLM.get("image2text_model", "gpt-4-vision-preview")
# distribution
@@ -3,7 +3,7 @@ import random
import re
import numpy as np
from rag.parser import bullets_category, BULLET_PATTERN, is_english, tokenize, remove_contents_table, \
hierarchical_merge, make_colon_as_title, naive_merge
hierarchical_merge, make_colon_as_title, naive_merge, random_choices
from rag.nlp import huqie
from rag.parser.docx_parser import HuDocxParser
from rag.parser.pdf_parser import HuParser
@@ -51,7 +51,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **k
doc_parser = HuDocxParser()
# TODO: table of contents need to be removed
sections, tbls = doc_parser(binary if binary else filename, from_page=from_page, to_page=to_page)
remove_contents_table(sections, eng=is_english(random.choices([t for t,_ in sections], k=200)))
remove_contents_table(sections, eng=is_english(random_choices([t for t,_ in sections], k=200)))
callback(0.8, "Finish parsing.")
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
pdf_parser = Pdf()
@@ -67,20 +67,20 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **k
l = f.readline()
if not l:break
txt += l
sections = txt.split("\n")
sections = txt.split("\n")
sections = [(l,"") for l in sections if l]
remove_contents_table(sections, eng = is_english(random.choices([t for t,_ in sections], k=200)))
remove_contents_table(sections, eng = is_english(random_choices([t for t,_ in sections], k=200)))
callback(0.8, "Finish parsing.")
else: raise NotImplementedError("file type not supported yet(docx, pdf, txt supported)")
make_colon_as_title(sections)
bull = bullets_category([t for t in random.choices([t for t,_ in sections], k=100)])
bull = bullets_category([t for t in random_choices([t for t,_ in sections], k=100)])
if bull >= 0: cks = hierarchical_merge(bull, sections, 3)
else: cks = naive_merge(sections, kwargs.get("chunk_token_num", 256), kwargs.get("delimer", "\n。;!?"))
sections = [t for t, _ in sections]
# is it English
eng = is_english(random.choices(sections, k=218))
eng = is_english(random_choices(sections, k=218))
res = []
# add tables
@@ -86,7 +86,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **k
l = f.readline()
if not l:break
txt += l
| sections = txt.split("\n") | |||
| sections = txt.split("\n") | |||
| sections = txt.split("\n") | |||
sections = [l for l in sections if l]
callback(0.8, "Finish parsing.")
else: raise NotImplementedError("file type not supported yet(docx, pdf, txt supported)")
@@ -52,7 +52,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **k
l = f.readline()
if not l:break
txt += l
sections = txt.split("\n")
sections = txt.split("\n")
sections = [(l,"") for l in sections if l]
callback(0.8, "Finish parsing.")
else: raise NotImplementedError("file type not supported yet(docx, pdf, txt supported)")
@@ -1,6 +1,9 @@
import copy
import re
from collections import Counter
from api.db import ParserType
from rag.cv.ppdetection import PPDet
from rag.parser import tokenize
from rag.nlp import huqie
from rag.parser.pdf_parser import HuParser
@@ -9,6 +12,10 @@ from rag.utils import num_tokens_from_string
class Pdf(HuParser):
def __init__(self):
self.model_speciess = ParserType.PAPER.value
super().__init__()
def __call__(self, filename, binary=None, from_page=0,
to_page=100000, zoomin=3, callback=None):
self.__images__(
@@ -63,6 +70,15 @@ class Pdf(HuParser):
"[0-9. 一、i]*(introduction|abstract|摘要|引言|keywords|key words|关键词|background|背景|目录|前言|contents)",
txt.lower().strip())
if from_page > 0:
return {
"title":"",
"authors": "",
"abstract": "",
"lines": [(b["text"] + self._line_tag(b, zoomin), b.get("layoutno", "")) for b in self.boxes[i:] if
re.match(r"(text|title)", b.get("layoutno", "text"))],
"tables": tbls
}
# get title and authors
title = ""
authors = []
@@ -115,18 +131,13 @@ class Pdf(HuParser):
def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **kwargs):
pdf_parser = None
paper = {}
if re.search(r"\.pdf$", filename, re.IGNORECASE):
pdf_parser = Pdf()
paper = pdf_parser(filename if not binary else binary,
from_page=from_page, to_page=to_page, callback=callback)
else: raise NotImplementedError("file type not supported yet(pdf supported)")
doc = {
"docnm_kwd": paper["title"] if paper["title"] else filename,
"authors_tks": paper["authors"]
}
doc["title_tks"] = huqie.qie(re.sub(r"\.[a-zA-Z]+$", "", doc["docnm_kwd"]))
doc = {"docnm_kwd": filename, "authors_tks": paper["authors"],
"title_tks": huqie.qie(paper["title"] if paper["title"] else filename)}
doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"])
doc["authors_sm_tks"] = huqie.qieqie(doc["authors_tks"])
# is it English
@@ -3,7 +3,7 @@ import re
from io import BytesIO
from nltk import word_tokenize
from openpyxl import load_workbook
from rag.parser import is_english
from rag.parser import is_english, random_choices
from rag.nlp import huqie, stemmer
@@ -33,9 +33,9 @@ class Excel(object):
if len(res) % 999 == 0:
callback(len(res)*0.6/total, ("Extract Q&A: {}".format(len(res)) + (f"{len(fails)} failure, line: %s..."%(",".join(fails[:3])) if fails else "")))
callback(0.6, ("Extract Q&A: {}".format(len(res)) + (
callback(0.6, ("Extract Q&A: {}. ".format(len(res)) + (
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
self.is_english = is_english([rmPrefix(q) for q, _ in random.choices(res, k=30) if len(q)>1])
self.is_english = is_english([rmPrefix(q) for q, _ in random_choices(res, k=30) if len(q)>1])
return res
@@ -0,0 +1,170 @@
import copy
import random
import re
from io import BytesIO
from xpinyin import Pinyin
import numpy as np
import pandas as pd
from nltk import word_tokenize
from openpyxl import load_workbook
from dateutil.parser import parse as datetime_parse
from rag.parser import is_english, tokenize
from rag.nlp import huqie, stemmer
class Excel(object):
def __call__(self, fnm, binary=None, callback=None):
if not binary:
wb = load_workbook(fnm)
else:
wb = load_workbook(BytesIO(binary))
total = 0
for sheetname in wb.sheetnames:
total += len(list(wb[sheetname].rows))
res, fails, done = [], [], 0
for sheetname in wb.sheetnames:
ws = wb[sheetname]
rows = list(ws.rows)
headers = [cell.value for cell in rows[0]]
missed = set([i for i,h in enumerate(headers) if h is None])
headers = [cell.value for i,cell in enumerate(rows[0]) if i not in missed]
data = []
for i, r in enumerate(rows[1:]):
row = [cell.value for ii,cell in enumerate(r) if ii not in missed]
if len(row) != len(headers):
fails.append(str(i))
continue
data.append(row)
done += 1
if done % 999 == 0:
callback(done * 0.6/total, ("Extract records: {}".format(len(res)) + (f"{len(fails)} failure({sheetname}), line: %s..."%(",".join(fails[:3])) if fails else "")))
res.append(pd.DataFrame(np.array(data), columns=headers))
callback(0.6, ("Extract records: {}. ".format(done) + (
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
return res
def trans_datatime(s):
try:
return datetime_parse(s.strip()).strftime("%Y-%m-%dT%H:%M:%S")
except Exception as e:
pass
def trans_bool(s):
if re.match(r"(true|yes|是)$", str(s).strip(), flags=re.IGNORECASE): return ["yes", "是"]
if re.match(r"(false|no|否)$", str(s).strip(), flags=re.IGNORECASE): return ["no", "否"]
def column_data_type(arr):
uni = len(set([a for a in arr if a is not None]))
counts = {"int": 0, "float": 0, "text": 0, "datetime": 0, "bool": 0}
trans = {t:f for f,t in [(int, "int"), (float, "float"), (trans_datatime, "datetime"), (trans_bool, "bool"), (str, "text")]}
for a in arr:
if a is None:continue
if re.match(r"[+-]?[0-9]+(\.0+)?$", str(a).replace("%%", "")):
counts["int"] += 1
elif re.match(r"[+-]?[0-9.]+$", str(a).replace("%%", "")):
counts["float"] += 1
elif re.match(r"(true|false|yes|no|是|否)$", str(a), flags=re.IGNORECASE):
counts["bool"] += 1
elif trans_datatime(str(a)):
counts["datetime"] += 1
else: counts["text"] += 1
counts = sorted(counts.items(), key=lambda x: x[1]*-1)
ty = counts[0][0]
for i in range(len(arr)):
if arr[i] is None:continue
try:
arr[i] = trans[ty](str(arr[i]))
except Exception as e:
arr[i] = None
if ty == "text":
if len(arr) > 128 and uni/len(arr) < 0.1:
ty = "keyword"
return arr, ty
def chunk(filename, binary=None, callback=None, **kwargs):
dfs = []
if re.search(r"\.xlsx?$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
excel_parser = Excel()
dfs = excel_parser(filename, binary, callback)
elif re.search(r"\.(txt|csv)$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
txt = ""
if binary:
txt = binary.decode("utf-8")
else:
with open(filename, "r") as f:
while True:
l = f.readline()
if not l: break
txt += l
lines = txt.split("\n")
fails = []
headers = lines[0].split(kwargs.get("delimiter", "\t"))
rows = []
for i, line in enumerate(lines[1:]):
row = [l for l in line.split(kwargs.get("delimiter", "\t"))]
if len(row) != len(headers):
fails.append(str(i))
continue
rows.append(row)
if len(rows) % 999 == 0:
callback(len(rows) * 0.6 / len(lines), ("Extract records: {}".format(len(rows)) + (
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
callback(0.6, ("Extract records: {}".format(len(rows)) + (
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
dfs = [pd.DataFrame(np.array(rows), columns=headers)]
else: raise NotImplementedError("file type not supported yet(excel, text, csv supported)")
res = []
PY = Pinyin()
| fieds_map = {"text": "_tks", "int": "_int", "keyword": "_kwd", "float": "_flt", "datetime": "_dt", "bool": "_kwd"} | |||
for df in dfs:
for n in ["id", "_id", "index", "idx"]:
if n in df.columns:del df[n]
clmns = df.columns.values
txts = list(copy.deepcopy(clmns))
py_clmns = [PY.get_pinyins(n)[0].replace("-", "_") for n in clmns]
clmn_tys = []
for j in range(len(clmns)):
cln,ty = column_data_type(df[clmns[j]])
clmn_tys.append(ty)
df[clmns[j]] = cln
if ty == "text": txts.extend([str(c) for c in cln if c])
clmns_map = [(py_clmns[j] + fields_map[clmn_tys[j]], clmns[j]) for j in range(len(clmns))]
# TODO: set this column map to KB parser configuration
eng = is_english(txts)
for ii,row in df.iterrows():
d = {}
row_txt = []
for j in range(len(clmns)):
if row[clmns[j]] is None:continue
fld = clmns_map[j][0]
d[fld] = row[clmns[j]] if clmn_tys[j] != "text" else huqie.qie(row[clmns[j]])
row_txt.append("{}:{}".format(clmns[j], row[clmns[j]]))
if not row_txt:continue
tokenize(d, "; ".join(row_txt), eng)
print(d)
res.append(d)
callback(0.6, "")
return res
if __name__== "__main__":
import sys
def dummy(a, b):
pass
chunk(sys.argv[1], callback=dummy)
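Quick usage check for column_data_type above (values illustrative): the majority type wins and cells that fail the cast become None.
arr, ty = column_data_type(["1", "2", None, "3"])
# ty == "int"; arr == [1, 2, None, 3]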
@@ -67,7 +67,7 @@ class Dealer:
ps = int(req.get("size", 1000))
src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id",
"image_id", "doc_id", "q_512_vec", "q_768_vec",
"q_1024_vec", "q_1536_vec", "available_int"])
"q_1024_vec", "q_1536_vec", "available_int", "content_with_weight"])
s = s.query(bqry)[pg * ps:(pg + 1) * ps]
s = s.highlight("content_ltks")
@@ -234,7 +234,7 @@ class Dealer:
sres.field[i].get("q_%d_vec" % len(sres.query_vector), "\t".join(["0"] * len(sres.query_vector)))) for i in sres.ids]
if not ins_embd:
return [], [], []
ins_tw = [huqie.qie(sres.field[i][cfield]).split(" ")
ins_tw = [sres.field[i][cfield].split(" ")
for i in sres.ids]
sim, tksim, vtsim = self.qryr.hybrid_similarity(sres.query_vector,
ins_embd,
@@ -281,6 +281,7 @@ class Dealer:
d = {
"chunk_id": id,
"content_ltks": sres.field[id]["content_ltks"],
"content_with_weight": sres.field[id]["content_with_weight"],
"doc_id": sres.field[id]["doc_id"],
"docnm_kwd": dnm,
"kb_id": sres.field[id]["kb_id"],
@@ -1,4 +1,5 @@
import copy
import random
from .pdf_parser import HuParser as PdfParser
from .docx_parser import HuDocxParser as DocxParser
@@ -38,6 +39,9 @@ BULLET_PATTERN = [[
]
]
def random_choices(arr, k):
k = min(len(arr), k)
return random.choices(arr, k=k)
def bullets_category(sections):
global BULLET_PATTERN
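Note on random_choices above: random.choices samples with replacement and raises IndexError on an empty population, so clamping k makes empty input return [] and keeps short inputs from being dominated by repeats. For example:
assert random_choices([], k=5) == []
assert len(random_choices(["a", "b", "c"], k=10)) == 3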
@@ -1,7 +1,10 @@
# -*- coding: utf-8 -*-
import os
import random
from functools import partial
import fitz
import requests
import xgboost as xgb
from io import BytesIO
import torch
@@ -10,13 +13,14 @@ import pdfplumber
import logging
from PIL import Image
import numpy as np
from api.db import ParserType
from rag.nlp import huqie
from collections import Counter
from copy import deepcopy
from rag.cv.table_recognize import TableTransformer
from rag.cv.ppdetection import PPDet
from huggingface_hub import hf_hub_download
logging.getLogger("pdfminer").setLevel(logging.WARNING)
@@ -25,8 +29,10 @@ class HuParser:
from paddleocr import PaddleOCR
logging.getLogger("ppocr").setLevel(logging.ERROR)
self.ocr = PaddleOCR(use_angle_cls=False, lang="ch")
self.layouter = PPDet("/data/newpeak/medical-gpt/res/ppdet")
self.tbl_det = PPDet("/data/newpeak/medical-gpt/res/ppdet.tbl")
if not hasattr(self, "model_speciess"):
self.model_speciess = ParserType.GENERAL.value
self.layouter = partial(self.__remote_call, self.model_speciess)
self.tbl_det = partial(self.__remote_call, "table_component")
self.updown_cnt_mdl = xgb.Booster()
if torch.cuda.is_available():
@@ -45,6 +51,38 @@ class HuParser:
"""
def __remote_call(self, species, images, thr=0.7):
url = os.environ.get("INFINIFLOW_SERVER")
if not url:raise EnvironmentError("Please set environment variable: 'INFINIFLOW_SERVER'")
token = os.environ.get("INFINIFLOW_TOKEN")
if not token:raise EnvironmentError("Please set environment variable: 'INFINIFLOW_TOKEN'")
def convert_image_to_bytes(PILimage):
image = BytesIO()
PILimage.save(image, format='png')
image.seek(0)
return image.getvalue()
images = [convert_image_to_bytes(img) for img in images]
def remote_call():
nonlocal images, thr
res = requests.post(url+"/v1/layout/detect/"+species, files=[("image", img) for img in images], data={"threashold": thr},
headers={"Authorization": token}, timeout=len(images) * 10)
res = res.json()
if res["retcode"] != 0: raise RuntimeError(res["retmsg"])
return res["data"]
for _ in range(3):
try:
return remote_call()
except RuntimeError as e:
raise e
except Exception as e:
logging.error("layout_predict:"+str(e))
return remote_call()
def __char_width(self, c):
return (c["x1"] - c["x0"]) // len(c["text"])
@@ -344,7 +382,7 @@ class HuParser:
return layouts
def __table_paddle(self, images):
tbls = self.tbl_det([np.array(img) for img in images], thr=0.5)
tbls = self.tbl_det(images, thr=0.5)
res = []
# align left&right for rows, align top&bottom for columns
for tbl in tbls:
@@ -522,7 +560,7 @@ class HuParser:
assert len(self.page_images) == len(self.boxes)
# Tag layout type
boxes = []
layouts = self.layouter([np.array(img) for img in self.page_images])
layouts = self.layouter(self.page_images)
assert len(self.page_images) == len(layouts)
for pn, lts in enumerate(layouts):
bxs = self.boxes[pn]
@@ -1705,7 +1743,8 @@ class HuParser:
self.__ocr_paddle(i + 1, img, chars, zoomin)
if not self.is_english and not any([c for c in self.page_chars]) and self.boxes:
self.is_english = re.search(r"[\na-zA-Z0-9,/¸;:'\[\]\(\)!@#$%^&*\"?<>._-]{30,}", "".join([b["text"] for b in random.choices([b for bxs in self.boxes for b in bxs], k=30)]))
bxes = [b for bxs in self.boxes for b in bxs]
self.is_english = re.search(r"[\na-zA-Z0-9,/¸;:'\[\]\(\)!@#$%^&*\"?<>._-]{30,}", "".join([b["text"] for b in random.choices(bxes, k=min(30, len(bxes)))]))
| logging.info("Is it English:", self.is_english) | |||
@@ -134,5 +134,5 @@ if __name__ == "__main__":
while True:
dispatch()
time.sleep(3)
time.sleep(1)
update_progress()
@@ -36,7 +36,7 @@ from rag.nlp import search
from io import BytesIO
import pandas as pd
from rag.app import laws, paper, presentation, manual, qa
from rag.app import laws, paper, presentation, manual, qa, table, book
from api.db import LLMType, ParserType
from api.db.services.document_service import DocumentService
@@ -49,10 +49,12 @@ BATCH_SIZE = 64
FACTORY = {
ParserType.GENERAL.value: laws,
ParserType.PAPER.value: paper,
ParserType.BOOK.value: book,
ParserType.PRESENTATION.value: presentation,
ParserType.MANUAL.value: manual,
ParserType.LAWS.value: laws,
ParserType.QA.value: qa,
ParserType.TABLE.value: table,
}
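Note: FACTORY keys are ParserType values and each entry is a module, so supporting a new document format is one mapping entry plus a module exposing the chunker contract sketched here (signature mirrored from the rag.app modules in this diff):
def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **kwargs):
    # parse the document and return a list of chunk dicts ready for indexing
    ...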
@@ -66,7 +68,7 @@ def set_progress(task_id, from_page, to_page, prog=None, msg="Processing..."):
d = {"progress_msg": msg}
if prog is not None: d["progress"] = prog
try:
TaskService.update_by_id(task_id, d)
TaskService.update_progress(task_id, d)
except Exception as e:
cron_logger.error("set_progress:({}), {}".format(task_id, str(e)))
@@ -113,7 +115,7 @@ def build(row, cvmdl):
return []
callback = partial(set_progress, row["id"], row["from_page"], row["to_page"])
chunker = FACTORY[row["parser_id"]]
chunker = FACTORY[row["parser_id"].lower()]
try:
| cron_logger.info("Chunkking {}/{}".format(row["location"], row["name"])) | |||
| cks = chunker.chunk(row["name"], MINIO.get(row["kb_id"], row["location"]), row["from_page"], row["to_page"], | |||
| @@ -154,6 +156,7 @@ def build(row, cvmdl): | |||
| MINIO.put(row["kb_id"], d["_id"], output_buffer.getvalue()) | |||
| d["img_id"] = "{}-{}".format(row["kb_id"], d["_id"]) | |||
| del d["image"] | |||
| docs.append(d) | |||
| return docs | |||
| @@ -168,7 +171,7 @@ def init_kb(row): | |||
| def embedding(docs, mdl): | |||
| tts, cnts = [d["docnm_kwd"] for d in docs if d.get("docnm_kwd")], [d["content_with_weight"] for d in docs] | |||
| tts, cnts = [rmSpace(d["title_tks"]) for d in docs if d.get("title_tks")], [d["content_with_weight"] for d in docs] | |||
| tk_count = 0 | |||
| if len(tts) == len(cnts): | |||
| tts, c = mdl.encode(tts) | |||
| @@ -207,6 +210,7 @@ def main(comm, mod): | |||
| cks = build(r, cv_mdl) | |||
| if not cks: | |||
| tmf.write(str(r["update_time"]) + "\n") | |||
| callback(1., "No chunk! Done!") | |||
| continue | |||
| # TODO: exception handler | |||
| ## set_progress(r["did"], -1, "ERROR: ") | |||
| @@ -215,7 +219,6 @@ def main(comm, mod): | |||
| except Exception as e: | |||
| callback(-1, "Embedding error:{}".format(str(e))) | |||
| cron_logger.error(str(e)) | |||
| continue | |||
| callback(msg="Finished embedding! Start to build index!") | |||
| init_kb(r) | |||
| @@ -227,6 +230,7 @@ def main(comm, mod): | |||
| else: | |||
| if TaskService.do_cancel(r["id"]): | |||
| ELASTICSEARCH.deleteByQuery(Q("match", doc_id=r["doc_id"]), idxnm=search.index_name(r["tenant_id"])) | |||
| continue | |||
| callback(1., "Done!") | |||
| DocumentService.increment_chunk_num(r["doc_id"], r["kb_id"], tk_count, chunk_count, 0) | |||
| cron_logger.info("Chunk doc({}), token({}), chunks({})".format(r["id"], tk_count, len(cks))) | |||