from flask_session import Session
from flask_login import LoginManager
from api.settings import RetCode, SECRET_KEY, stat_logger
from api.hook import HookManager
from api.hook.common.parameters import AuthenticationParameters, ClientAuthenticationParameters
from api.settings import API_VERSION, CLIENT_AUTHENTICATION, SITE_AUTHENTICATION, access_logger
from api.utils.api_utils import get_json_result, server_error_response
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
]
def client_authentication_before_request():
    result = HookManager.client_authentication(ClientAuthenticationParameters(
        request.full_path, request.headers,
        request.form, request.data, request.json,
    ))
    if result.code != RetCode.SUCCESS:
        return get_json_result(result.code, result.message)


def site_authentication_before_request():
    for url_prefix in client_urls_prefix:
        if request.path.startswith(url_prefix):
            return
    result = HookManager.site_authentication(AuthenticationParameters(
        request.headers.get('site_signature'),
        request.json,
    ))
    if result.code != RetCode.SUCCESS:
        return get_json_result(result.code, result.message)


@app.before_request
def authentication_before_request():
    if CLIENT_AUTHENTICATION:
        return client_authentication_before_request()
    if SITE_AUTHENTICATION:
        return site_authentication_before_request()


@login_manager.request_loader
def load_user(web_request):
for id in sres.ids:
    d = {
        "chunk_id": id,
        "content_ltks": rmSpace(sres.highlight[id]) if question else sres.field[id]["content_ltks"],
        "content_with_weight": rmSpace(sres.highlight[id]) if question else sres.field[id]["content_with_weight"],
        "doc_id": sres.field[id]["doc_id"],
        "docnm_kwd": sres.field[id]["docnm_kwd"],
        "important_kwd": sres.field[id].get("important_kwd", []),
    q, a = rmPrefix(arr[0]), rmPrefix(arr[1])
    d = beAdoc(d, arr[0], arr[1], not any([huqie.is_chinese(t) for t in q + a]))
| v, c = embd_mdl.encode([doc.name, req["content_ltks"]]) | |||||
| v, c = embd_mdl.encode([doc.name, req["content_with_weight"]]) | |||||
| v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1] | v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1] | ||||
| d["q_%d_vec" % len(v)] = v.tolist() | d["q_%d_vec" % len(v)] = v.tolist() | ||||
| ELASTICSEARCH.upsert([d], search.index_name(tenant_id)) | ELASTICSEARCH.upsert([d], search.index_name(tenant_id)) | ||||
@manager.route('/create', methods=['POST'])
@login_required
@validate_request("doc_id", "content_ltks")
@validate_request("doc_id", "content_with_weight")
def create():
    req = request.json
    md5 = hashlib.md5()
    md5.update((req["content_ltks"] + req["doc_id"]).encode("utf-8"))
    md5.update((req["content_with_weight"] + req["doc_id"]).encode("utf-8"))
    chunck_id = md5.hexdigest()
    d = {"id": chunck_id, "content_ltks": huqie.qie(req["content_ltks"])}
    d = {"id": chunck_id, "content_ltks": huqie.qie(req["content_with_weight"])}
    d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"])
    d["important_kwd"] = req.get("important_kwd", [])
    d["important_tks"] = huqie.qie(" ".join(req.get("important_kwd", [])))
    embd_mdl = TenantLLMService.model_instance(
        tenant_id, LLMType.EMBEDDING.value)
    v, c = embd_mdl.encode([doc.name, req["content_ltks"]])
    v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
    DocumentService.increment_chunk_num(req["doc_id"], doc.kb_id, c, 1, 0)
    v = 0.1 * v[0] + 0.9 * v[1]
    d["q_%d_vec" % len(v)] = v.tolist()
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
kbinfos = retrievaler.retrieval(question, embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n, dialog.similarity_threshold,
                                dialog.vector_similarity_weight, top=1024, aggs=False)
knowledges = [ck["content_ltks"] for ck in kbinfos["chunks"]]
knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
if not knowledges and prompt_config["empty_response"]:
    return {"answer": prompt_config["empty_response"], "retrieval": kbinfos}
| "id": get_uuid(), | "id": get_uuid(), | ||||
| "kb_id": kb.id, | "kb_id": kb.id, | ||||
| "parser_id": kb.parser_id, | "parser_id": kb.parser_id, | ||||
| "parser_config": kb.parser_config, | |||||
| "created_by": current_user.id, | "created_by": current_user.id, | ||||
| "type": filename_type(filename), | "type": filename_type(filename), | ||||
| "name": filename, | "name": filename, | ||||
| "id": get_uuid(), | "id": get_uuid(), | ||||
| "kb_id": kb.id, | "kb_id": kb.id, | ||||
| "parser_id": kb.parser_id, | "parser_id": kb.parser_id, | ||||
| "parser_config": kb.parser_config, | |||||
| "created_by": current_user.id, | "created_by": current_user.id, | ||||
| "type": FileType.VIRTUAL, | "type": FileType.VIRTUAL, | ||||
| "name": req["name"], | "name": req["name"], | ||||
| data=False, retmsg='Lack of "KB ID"', retcode=RetCode.ARGUMENT_ERROR) | data=False, retmsg='Lack of "KB ID"', retcode=RetCode.ARGUMENT_ERROR) | ||||
| keywords = request.args.get("keywords", "") | keywords = request.args.get("keywords", "") | ||||
| page_number = request.args.get("page", 1) | |||||
| items_per_page = request.args.get("page_size", 15) | |||||
| page_number = int(request.args.get("page", 1)) | |||||
| items_per_page = int(request.args.get("page_size", 15)) | |||||
| orderby = request.args.get("orderby", "create_time") | orderby = request.args.get("orderby", "create_time") | ||||
| desc = request.args.get("desc", True) | desc = request.args.get("desc", True) | ||||
| try: | try: | ||||
| req = request.json | req = request.json | ||||
| try: | try: | ||||
| for id in req["doc_ids"]: | for id in req["doc_ids"]: | ||||
| DocumentService.update_by_id(id, {"run": str(req["run"]), "progress": 0}) | |||||
| info = {"run": str(req["run"]), "progress": 0} | |||||
| if str(req["run"]) == TaskStatus.RUNNING.value:info["progress_msg"] = "" | |||||
| DocumentService.update_by_id(id, info) | |||||
| if str(req["run"]) == TaskStatus.CANCEL.value: | if str(req["run"]) == TaskStatus.CANCEL.value: | ||||
| tenant_id = DocumentService.get_tenant_id(id) | tenant_id = DocumentService.get_tenant_id(id) | ||||
| if not tenant_id: | if not tenant_id: |
@manager.route('/create', methods=['post'])
@login_required
@validate_request("name", "description", "permission", "parser_id")
@validate_request("name")
def create():
    req = request.json
    req["name"] = req["name"].strip()
| RESUME = "resume" | RESUME = "resume" | ||||
| BOOK = "book" | BOOK = "book" | ||||
| QA = "qa" | QA = "qa" | ||||
| TABLE = "table" |
)
from playhouse.pool import PooledMySQLDatabase
from api.db import SerializedType
from api.db import SerializedType, ParserType
from api.settings import DATABASE, stat_logger, SECRET_KEY
from api.utils.log_utils import getLogger
from api import utils

embd_id = CharField(max_length=128, null=False, help_text="default embedding model ID")
asr_id = CharField(max_length=128, null=False, help_text="default ASR model ID")
img2txt_id = CharField(max_length=128, null=False, help_text="default image to text model ID")
parser_ids = CharField(max_length=128, null=False, help_text="default image to text model ID")
parser_ids = CharField(max_length=128, null=False, help_text="document processors")
credit = IntegerField(default=512)
status = CharField(max_length=1, null=True, help_text="is it validated (0: wasted, 1: validated)", default="1")

class Meta:

similarity_threshold = FloatField(default=0.2)
vector_similarity_weight = FloatField(default=0.3)
parser_id = CharField(max_length=32, null=False, help_text="default parser ID")
parser_id = CharField(max_length=32, null=False, help_text="default parser ID", default=ParserType.GENERAL.value)
parser_config = JSONField(null=False, default={"from_page": 0, "to_page": 100000})
status = CharField(max_length=1, null=True, help_text="is it validated (0: wasted, 1: validated)", default="1")

def __str__(self):

thumbnail = TextField(null=True, help_text="thumbnail base64 string")
kb_id = CharField(max_length=256, null=False, index=True)
parser_id = CharField(max_length=32, null=False, help_text="default parser ID")
parser_config = JSONField(null=False, default={"from_page": 0, "to_page": 100000})
source_type = CharField(max_length=128, null=False, default="local", help_text="where does this document come from")
type = CharField(max_length=32, null=False, help_text="file extension")
created_by = CharField(max_length=32, null=False, help_text="who created it")
#
# Copyright 2021 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import abc
import json
import time
from functools import wraps

from shortuuid import ShortUUID

from api.versions import get_rag_version
from api.errors.error_services import *
from api.settings import (
    GRPC_PORT, HOST, HTTP_PORT,
    RANDOM_INSTANCE_ID, stat_logger,
)

instance_id = ShortUUID().random(length=8) if RANDOM_INSTANCE_ID else f'flow-{HOST}-{HTTP_PORT}'
server_instance = (
    f'{HOST}:{GRPC_PORT}',
    json.dumps({
        'instance_id': instance_id,
        'timestamp': round(time.time() * 1000),
        'version': get_rag_version() or '',
        'host': HOST,
        'grpc_port': GRPC_PORT,
        'http_port': HTTP_PORT,
    }),
)


def check_service_supported(method):
    """Decorator to check if `service_name` is supported.
    The attribute `supported_services` MUST be defined in the class.
    The first and second arguments of `method` MUST be `self` and `service_name`.

    :param Callable method: The class method.
    :return: The inner wrapper function.
    :rtype: Callable
    """
    @wraps(method)
    def magic(self, service_name, *args, **kwargs):
        if service_name not in self.supported_services:
            raise ServiceNotSupported(service_name=service_name)
        return method(self, service_name, *args, **kwargs)
    return magic
class ServicesDB(abc.ABC):
    """Database for storing service urls.
    Abstract base class for the real backends.
    """
    @property
    @abc.abstractmethod
    def supported_services(self):
        """The names of supported services.
        The returned list SHOULD contain `ragflow` (model download) and `servings` (RAG-Serving).

        :return: The service names.
        :rtype: list
        """
        pass

    @abc.abstractmethod
    def _get_serving(self):
        pass

    def get_serving(self):
        try:
            return self._get_serving()
        except ServicesError as e:
            stat_logger.exception(e)
            return []

    @abc.abstractmethod
    def _insert(self, service_name, service_url, value=''):
        pass

    @check_service_supported
    def insert(self, service_name, service_url, value=''):
        """Insert a service url into the database.

        :param str service_name: The service name.
        :param str service_url: The service url.
        :return: None
        """
        try:
            self._insert(service_name, service_url, value)
        except ServicesError as e:
            stat_logger.exception(e)

    @abc.abstractmethod
    def _delete(self, service_name, service_url):
        pass

    @check_service_supported
    def delete(self, service_name, service_url):
        """Delete a service url from the database.

        :param str service_name: The service name.
        :param str service_url: The service url.
        :return: None
        """
        try:
            self._delete(service_name, service_url)
        except ServicesError as e:
            stat_logger.exception(e)

    def register_flow(self):
        """Call `self.insert` to insert the flow server address into the database.

        :return: None
        """
        self.insert('flow-server', *server_instance)

    def unregister_flow(self):
        """Call `self.delete` to delete the flow server address from the database.

        :return: None
        """
        self.delete('flow-server', server_instance[0])

    @abc.abstractmethod
    def _get_urls(self, service_name, with_values=False):
        pass

    @check_service_supported
    def get_urls(self, service_name, with_values=False):
        """Query service urls from the database. The urls may belong to other nodes.
        Currently, only `ragflow` (model download) urls and `servings` (RAG-Serving) urls are supported.
        `ragflow` is a url containing scheme, host, port and path,
        while `servings` only contains host and port.

        :param str service_name: The service name.
        :return: The service urls.
        :rtype: list
        """
        try:
            return self._get_urls(service_name, with_values)
        except ServicesError as e:
            stat_logger.exception(e)
            return []
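# A minimal sketch of a concrete backend for the abstract class above, using an
# in-memory dict instead of a real registry (this class and its storage layout are
# hypothetical; the real backends would sit behind ZooKeeper or a database).
class InMemoryServicesDB(ServicesDB):
    supported_services = ['ragflow', 'servings', 'flow-server']

    def __init__(self):
        self._store = {}  # service_name -> {service_url: value}

    def _get_serving(self):
        return list(self._store.get('servings', {}).keys())

    def _insert(self, service_name, service_url, value=''):
        self._store.setdefault(service_name, {})[service_url] = value

    def _delete(self, service_name, service_url):
        self._store.get(service_name, {}).pop(service_url, None)

    def _get_urls(self, service_name, with_values=False):
        urls = self._store.get(service_name, {})
        return list(urls.items()) if with_values else list(urls.keys())

# db = InMemoryServicesDB()
# db.register_flow()        # stores f'{HOST}:{GRPC_PORT}' under 'flow-server'
# db.get_urls('servings')   # -> [] until a serving registers itself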
@classmethod
@DB.connection_context()
def get_newly_uploaded(cls, tm, mod=0, comm=1, items_per_page=64):
    fields = [cls.model.id, cls.model.kb_id, cls.model.parser_id, cls.model.name, cls.model.type, cls.model.location, cls.model.size, Knowledgebase.tenant_id, Tenant.embd_id, Tenant.img2txt_id, Tenant.asr_id, cls.model.update_time]
    fields = [cls.model.id, cls.model.kb_id, cls.model.parser_id, cls.model.parser_config, cls.model.name, cls.model.type, cls.model.location, cls.model.size, Knowledgebase.tenant_id, Tenant.embd_id, Tenant.img2txt_id, Tenant.asr_id, cls.model.update_time]
    docs = cls.model.select(*fields) \
        .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) \
        .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))\
cls.model.doc_num,
cls.model.token_num,
cls.model.chunk_num,
cls.model.parser_id]
cls.model.parser_id,
cls.model.parser_config]
kbs = cls.model.select(*fields).join(Tenant, on=((Tenant.id == cls.model.tenant_id) & (Tenant.status == StatusEnum.VALID.value))).where(
    (cls.model.id == kb_id),
    (cls.model.status == StatusEnum.VALID.value)
@classmethod
@DB.connection_context()
def get_tasks(cls, tm, mod=0, comm=1, items_per_page=64):
    fields = [cls.model.id, cls.model.doc_id, cls.model.from_page, cls.model.to_page, Document.kb_id, Document.parser_id, Document.name, Document.type, Document.location, Document.size, Knowledgebase.tenant_id, Tenant.embd_id, Tenant.img2txt_id, Tenant.asr_id, cls.model.update_time]
    fields = [cls.model.id, cls.model.doc_id, cls.model.from_page, cls.model.to_page, Document.kb_id, Document.parser_id, Document.parser_config, Document.name, Document.type, Document.location, Document.size, Knowledgebase.tenant_id, Tenant.embd_id, Tenant.img2txt_id, Tenant.asr_id, cls.model.update_time]
    docs = cls.model.select(*fields) \
        .join(Document, on=(cls.model.doc_id == Document.id)) \
        .join(Knowledgebase, on=(Document.kb_id == Knowledgebase.id)) \

    except Exception as e:
        pass
    return True

@classmethod
@DB.connection_context()
def update_progress(cls, id, info):
    cls.model.update(progress_msg=cls.model.progress_msg + "\n" + info["progress_msg"]).where(
        cls.model.id == id).execute()
    if "progress" in info:
        cls.model.update(progress=info["progress"]).where(
            cls.model.id == id).execute()
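# A small usage sketch matching the info dict shape built by set_progress in the task
# executor below (task_id here is a placeholder): progress_msg is appended to the
# existing message, and progress is only written when the key is present.
TaskService.update_progress(task_id, {"progress_msg": "Start to parse.", "progress": 0.1})
TaskService.update_progress(task_id, {"progress_msg": "Finish parsing."})  # message only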
    .join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id == user_id) & (UserTenant.status == StatusEnum.VALID.value) & (UserTenant.role == UserTenantRole.NORMAL.value)))\
    .where(cls.model.status == StatusEnum.VALID.value).dicts())

@classmethod
@DB.connection_context()
def decrease(cls, user_id, num):
    num = cls.model.update(credit=cls.model.credit - num).where(
        cls.model.id == user_id).execute()
    if num == 0: raise LookupError("Tenant not found which is supposed to be there")

class UserTenantService(CommonService):
    model = UserTenant
from .general_error import *


class RagFlowError(Exception):
    message = 'Unknown Rag Flow Error'

    def __init__(self, message=None, *args, **kwargs):
        message = str(message) if message is not None else self.message
        message = message.format(*args, **kwargs)
        super().__init__(message)
from api.errors import RagFlowError

__all__ = ['ServicesError', 'ServiceNotSupported', 'ZooKeeperNotConfigured',
           'MissingZooKeeperUsernameOrPassword', 'ZooKeeperBackendError']


class ServicesError(RagFlowError):
    message = 'Unknown services error'


class ServiceNotSupported(ServicesError):
    message = 'The service {service_name} is not supported'
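# A quick sketch of how the error message template is filled in: RagFlowError formats
# its class-level message with the keyword arguments passed to the constructor.
try:
    raise ServiceNotSupported(service_name="layout")
except ServicesError as e:
    print(e)  # -> The service layout is not supported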
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


class ParameterError(Exception):
    pass


class PassError(Exception):
    pass
import importlib

from api.hook.common.parameters import SignatureParameters, AuthenticationParameters, \
    SignatureReturn, AuthenticationReturn, PermissionReturn, ClientAuthenticationReturn, ClientAuthenticationParameters
from api.settings import HOOK_MODULE, stat_logger, RetCode


class HookManager:
    SITE_SIGNATURE = []
    SITE_AUTHENTICATION = []
    CLIENT_AUTHENTICATION = []
    PERMISSION_CHECK = []

    @staticmethod
    def init():
        if HOOK_MODULE is not None:
            for modules in HOOK_MODULE.values():
                for module in modules.split(";"):
                    try:
                        importlib.import_module(module)
                    except Exception as e:
                        stat_logger.exception(e)

    @staticmethod
    def register_site_signature_hook(func):
        HookManager.SITE_SIGNATURE.append(func)

    @staticmethod
    def register_site_authentication_hook(func):
        HookManager.SITE_AUTHENTICATION.append(func)

    @staticmethod
    def register_client_authentication_hook(func):
        HookManager.CLIENT_AUTHENTICATION.append(func)

    @staticmethod
    def register_permission_check_hook(func):
        HookManager.PERMISSION_CHECK.append(func)

    @staticmethod
    def client_authentication(parm: ClientAuthenticationParameters) -> ClientAuthenticationReturn:
        if HookManager.CLIENT_AUTHENTICATION:
            return HookManager.CLIENT_AUTHENTICATION[0](parm)
        return ClientAuthenticationReturn()

    @staticmethod
    def site_signature(parm: SignatureParameters) -> SignatureReturn:
        if HookManager.SITE_SIGNATURE:
            return HookManager.SITE_SIGNATURE[0](parm)
        return SignatureReturn()

    @staticmethod
    def site_authentication(parm: AuthenticationParameters) -> AuthenticationReturn:
        if HookManager.SITE_AUTHENTICATION:
            return HookManager.SITE_AUTHENTICATION[0](parm)
        return AuthenticationReturn()
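# A minimal sketch of how a hook plugs into HookManager, assuming the module that
# defines it is listed in HOOK_MODULE so HookManager.init() imports it. The
# reject-on-missing-header rule and the use of RetCode.AUTHENTICATION_ERROR are
# illustrative assumptions, not part of the shipped hooks.
from api.hook import HookManager
from api.hook.common.parameters import ClientAuthenticationParameters, ClientAuthenticationReturn
from api.settings import RetCode


@HookManager.register_client_authentication_hook
def deny_anonymous(parm: ClientAuthenticationParameters) -> ClientAuthenticationReturn:
    if not parm.headers.get("Authorization"):
        return ClientAuthenticationReturn(code=RetCode.AUTHENTICATION_ERROR,
                                          message="missing Authorization header")
    return ClientAuthenticationReturn()

# HookManager.client_authentication() dispatches to the first registered hook, so a
# non-SUCCESS code here becomes the JSON error returned by the before_request handler.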
import requests

from api.db.service_registry import ServiceRegistry
from api.settings import RegistryServiceName
from api.hook import HookManager
from api.hook.common.parameters import ClientAuthenticationParameters, ClientAuthenticationReturn
from api.settings import HOOK_SERVER_NAME


@HookManager.register_client_authentication_hook
def authentication(parm: ClientAuthenticationParameters) -> ClientAuthenticationReturn:
    service_list = ServiceRegistry.load_service(
        server_name=HOOK_SERVER_NAME,
        service_name=RegistryServiceName.CLIENT_AUTHENTICATION.value
    )
    if not service_list:
        raise Exception(f"client authentication error: no client_authentication service found for server"
                        f" {HOOK_SERVER_NAME}")
    service = service_list[0]
    response = getattr(requests, service.f_method.lower(), None)(
        url=service.f_url,
        json=parm.to_dict()
    )
    if response.status_code != 200:
        raise Exception(
            f"client authentication error: request authentication url failed, status code {response.status_code}")
    elif response.json().get("code") != 0:
        return ClientAuthenticationReturn(code=response.json().get("code"), message=response.json().get("msg"))
    return ClientAuthenticationReturn()
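# A rough sketch of the remote service this hook expects, inferred from how the
# response is parsed above (HTTP 200 plus a JSON body with "code" and "msg"); the
# endpoint, its route and its accept rule are hypothetical, and any non-zero code
# is propagated back as the authentication error.
from flask import Flask, jsonify, request

hook_app = Flask(__name__)


@hook_app.route("/client_authentication", methods=["POST"])
def client_authentication():
    parm = request.json  # the ClientAuthenticationParameters dict POSTed by the hook
    if parm.get("headers", {}).get("Authorization"):
        return jsonify({"code": 0, "msg": "success"})
    return jsonify({"code": 401, "msg": "missing Authorization header"})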
import requests

from api.db.service_registry import ServiceRegistry
from api.settings import RegistryServiceName
from api.hook import HookManager
from api.hook.common.parameters import PermissionCheckParameters, PermissionReturn
from api.settings import HOOK_SERVER_NAME


@HookManager.register_permission_check_hook
def permission(parm: PermissionCheckParameters) -> PermissionReturn:
    service_list = ServiceRegistry.load_service(server_name=HOOK_SERVER_NAME, service_name=RegistryServiceName.PERMISSION_CHECK.value)
    if not service_list:
        raise Exception(f"permission check error: no permission service found for server {HOOK_SERVER_NAME}")
    service = service_list[0]
    response = getattr(requests, service.f_method.lower(), None)(
        url=service.f_url,
        json=parm.to_dict()
    )
    if response.status_code != 200:
        raise Exception(
            f"permission check error: request permission url failed, status code {response.status_code}")
    elif response.json().get("code") != 0:
        return PermissionReturn(code=response.json().get("code"), message=response.json().get("msg"))
    return PermissionReturn()
import requests

from api.db.service_registry import ServiceRegistry
from api.settings import RegistryServiceName
from api.hook import HookManager
from api.hook.common.parameters import SignatureParameters, AuthenticationParameters, AuthenticationReturn,\
    SignatureReturn
from api.settings import HOOK_SERVER_NAME, PARTY_ID


@HookManager.register_site_signature_hook
def signature(parm: SignatureParameters) -> SignatureReturn:
    service_list = ServiceRegistry.load_service(server_name=HOOK_SERVER_NAME, service_name=RegistryServiceName.SIGNATURE.value)
    if not service_list:
        raise Exception(f"signature error: no signature service found for server {HOOK_SERVER_NAME}")
    service = service_list[0]
    response = getattr(requests, service.f_method.lower(), None)(
        url=service.f_url,
        json=parm.to_dict()
    )
    if response.status_code == 200:
        if response.json().get("code") == 0:
            return SignatureReturn(site_signature=response.json().get("data"))
        else:
            raise Exception(f"signature error: request signature url failed, result: {response.json()}")
    else:
        raise Exception(f"signature error: request signature url failed, status code {response.status_code}")


@HookManager.register_site_authentication_hook
def authentication(parm: AuthenticationParameters) -> AuthenticationReturn:
    if not parm.src_party_id or str(parm.src_party_id) == "0":
        parm.src_party_id = PARTY_ID
    service_list = ServiceRegistry.load_service(server_name=HOOK_SERVER_NAME,
                                                service_name=RegistryServiceName.SITE_AUTHENTICATION.value)
    if not service_list:
        raise Exception(
            f"site authentication error: no site_authentication service found for server {HOOK_SERVER_NAME}")
    service = service_list[0]
    response = getattr(requests, service.f_method.lower(), None)(
        url=service.f_url,
        json=parm.to_dict()
    )
    if response.status_code != 200:
        raise Exception(
            f"site authentication error: request site_authentication url failed, status code {response.status_code}")
    elif response.json().get("code") != 0:
        return AuthenticationReturn(code=response.json().get("code"), message=response.json().get("msg"))
    return AuthenticationReturn()
from api.settings import RetCode


class ParametersBase:
    def to_dict(self):
        d = {}
        for k, v in self.__dict__.items():
            d[k] = v
        return d


class ClientAuthenticationParameters(ParametersBase):
    def __init__(self, full_path, headers, form, data, json):
        self.full_path = full_path
        self.headers = headers
        self.form = form
        self.data = data
        self.json = json


class ClientAuthenticationReturn(ParametersBase):
    def __init__(self, code=RetCode.SUCCESS, message="success"):
        self.code = code
        self.message = message


class SignatureParameters(ParametersBase):
    def __init__(self, party_id, body):
        self.party_id = party_id
        self.body = body


class SignatureReturn(ParametersBase):
    def __init__(self, code=RetCode.SUCCESS, site_signature=None):
        self.code = code
        self.site_signature = site_signature


class AuthenticationParameters(ParametersBase):
    def __init__(self, site_signature, body):
        self.site_signature = site_signature
        self.body = body


class AuthenticationReturn(ParametersBase):
    def __init__(self, code=RetCode.SUCCESS, message="success"):
        self.code = code
        self.message = message


class PermissionReturn(ParametersBase):
    def __init__(self, code=RetCode.SUCCESS, message="success"):
        self.code = code
        self.message = message
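# A small sketch of the wire format: to_dict() just exposes the instance attributes,
# which is what the remote hooks above send as the JSON request body.
parm = AuthenticationParameters(site_signature="sig-abc", body={"question": "hi"})
print(parm.to_dict())  # -> {'site_signature': 'sig-abc', 'body': {'question': 'hi'}}
ret = AuthenticationReturn()
print(ret.to_dict())   # -> {'code': RetCode.SUCCESS, 'message': 'success'}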
import signal
import sys
import traceback

from werkzeug.serving import run_simple

from api.apps import app
from api.db.runtime_config import RuntimeConfig
from api.hook import HookManager
from api.settings import (
    HOST, HTTP_PORT, access_logger, database_logger, stat_logger,
)

RuntimeConfig.init_env()
RuntimeConfig.init_config(JOB_SERVER_HOST=HOST, HTTP_PORT=HTTP_PORT)

HookManager.init()

peewee_logger = logging.getLogger('peewee')
peewee_logger.propagate = False
# rag_arch.common.log.ROpenHandler
| CHAT_MDL = LLM.get("chat_model", "gpt-3.5-turbo") | CHAT_MDL = LLM.get("chat_model", "gpt-3.5-turbo") | ||||
| EMBEDDING_MDL = LLM.get("embedding_model", "text-embedding-ada-002") | EMBEDDING_MDL = LLM.get("embedding_model", "text-embedding-ada-002") | ||||
| ASR_MDL = LLM.get("asr_model", "whisper-1") | ASR_MDL = LLM.get("asr_model", "whisper-1") | ||||
| PARSERS = LLM.get("parsers", "General,Resume,Laws,Product Instructions,Books,Paper,Q&A,Programming Code,Power Point,Research Report") | |||||
| PARSERS = LLM.get("parsers", "general:General,resume:esume,laws:Laws,manual:Manual,book:Book,paper:Paper,qa:Q&A,presentation:Presentation") | |||||
| IMAGE2TEXT_MDL = LLM.get("image2text_model", "gpt-4-vision-preview") | IMAGE2TEXT_MDL = LLM.get("image2text_model", "gpt-4-vision-preview") | ||||
| # distribution | # distribution |
import re
import numpy as np
from rag.parser import bullets_category, BULLET_PATTERN, is_english, tokenize, remove_contents_table, \
    hierarchical_merge, make_colon_as_title, naive_merge
    hierarchical_merge, make_colon_as_title, naive_merge, random_choices
from rag.nlp import huqie
from rag.parser.docx_parser import HuDocxParser
from rag.parser.pdf_parser import HuParser

doc_parser = HuDocxParser()
# TODO: table of contents need to be removed
sections, tbls = doc_parser(binary if binary else filename, from_page=from_page, to_page=to_page)
remove_contents_table(sections, eng=is_english(random.choices([t for t, _ in sections], k=200)))
remove_contents_table(sections, eng=is_english(random_choices([t for t, _ in sections], k=200)))
callback(0.8, "Finish parsing.")
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
    pdf_parser = Pdf()

l = f.readline()
if not l: break
txt += l
sections = txt.split("\n")
sections = [(l, "") for l in sections if l]
remove_contents_table(sections, eng=is_english(random.choices([t for t, _ in sections], k=200)))
remove_contents_table(sections, eng=is_english(random_choices([t for t, _ in sections], k=200)))
callback(0.8, "Finish parsing.")
else: raise NotImplementedError("file type not supported yet(docx, pdf, txt supported)")

make_colon_as_title(sections)
bull = bullets_category([t for t in random.choices([t for t, _ in sections], k=100)])
bull = bullets_category([t for t in random_choices([t for t, _ in sections], k=100)])
if bull >= 0: cks = hierarchical_merge(bull, sections, 3)
else: cks = naive_merge(sections, kwargs.get("chunk_token_num", 256), kwargs.get("delimer", "\n。;!?"))

sections = [t for t, _ in sections]
# is it English
eng = is_english(random.choices(sections, k=218))
eng = is_english(random_choices(sections, k=218))
res = []
# add tables
l = f.readline()
if not l: break
txt += l
sections = txt.split("\n")
sections = [l for l in sections if l]
callback(0.8, "Finish parsing.")
else: raise NotImplementedError("file type not supported yet(docx, pdf, txt supported)")
l = f.readline()
if not l: break
txt += l
sections = txt.split("\n")
sections = [(l, "") for l in sections if l]
callback(0.8, "Finish parsing.")
else: raise NotImplementedError("file type not supported yet(docx, pdf, txt supported)")
import copy
import re
from collections import Counter

from api.db import ParserType
from rag.cv.ppdetection import PPDet
from rag.parser import tokenize
from rag.nlp import huqie
from rag.parser.pdf_parser import HuParser


class Pdf(HuParser):
    def __init__(self):
        self.model_speciess = ParserType.PAPER.value
        super().__init__()

    def __call__(self, filename, binary=None, from_page=0,
                 to_page=100000, zoomin=3, callback=None):
        self.__images__(

        "[0-9. 一、i]*(introduction|abstract|摘要|引言|keywords|key words|关键词|background|背景|目录|前言|contents)",
        txt.lower().strip())

        if from_page > 0:
            return {
                "title": "",
                "authors": "",
                "abstract": "",
                "lines": [(b["text"] + self._line_tag(b, zoomin), b.get("layoutno", "")) for b in self.boxes[i:] if
                          re.match(r"(text|title)", b.get("layoutno", "text"))],
                "tables": tbls
            }

        # get title and authors
        title = ""
        authors = []


def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **kwargs):
    pdf_parser = None
    paper = {}
    if re.search(r"\.pdf$", filename, re.IGNORECASE):
        pdf_parser = Pdf()
        paper = pdf_parser(filename if not binary else binary,
                           from_page=from_page, to_page=to_page, callback=callback)
    else: raise NotImplementedError("file type not supported yet(pdf supported)")
    doc = {
        "docnm_kwd": paper["title"] if paper["title"] else filename,
        "authors_tks": paper["authors"]
    }
    doc["title_tks"] = huqie.qie(re.sub(r"\.[a-zA-Z]+$", "", doc["docnm_kwd"]))
    doc = {"docnm_kwd": filename, "authors_tks": paper["authors"],
           "title_tks": huqie.qie(paper["title"] if paper["title"] else filename)}
    doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"])
    doc["authors_sm_tks"] = huqie.qieqie(doc["authors_tks"])
    # is it English
from io import BytesIO
from nltk import word_tokenize
from openpyxl import load_workbook
from rag.parser import is_english
from rag.parser import is_english, random_choices
from rag.nlp import huqie, stemmer

if len(res) % 999 == 0:
    callback(len(res) * 0.6 / total, ("Extract Q&A: {}".format(len(res)) + (f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
callback(0.6, ("Extract Q&A: {}".format(len(res)) + (
callback(0.6, ("Extract Q&A: {}. ".format(len(res)) + (
    f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
self.is_english = is_english([rmPrefix(q) for q, _ in random.choices(res, k=30) if len(q) > 1])
self.is_english = is_english([rmPrefix(q) for q, _ in random_choices(res, k=30) if len(q) > 1])
return res
import copy
import random
import re
from io import BytesIO
from xpinyin import Pinyin
import numpy as np
import pandas as pd
from nltk import word_tokenize
from openpyxl import load_workbook
from dateutil.parser import parse as datetime_parse

from rag.parser import is_english, tokenize
from rag.nlp import huqie, stemmer


class Excel(object):
    def __call__(self, fnm, binary=None, callback=None):
        if not binary:
            wb = load_workbook(fnm)
        else:
            wb = load_workbook(BytesIO(binary))
        total = 0
        for sheetname in wb.sheetnames:
            total += len(list(wb[sheetname].rows))

        res, fails, done = [], [], 0
        for sheetname in wb.sheetnames:
            ws = wb[sheetname]
            rows = list(ws.rows)
            headers = [cell.value for cell in rows[0]]
            missed = set([i for i, h in enumerate(headers) if h is None])
            headers = [cell.value for i, cell in enumerate(rows[0]) if i not in missed]
            data = []
            for i, r in enumerate(rows[1:]):
                row = [cell.value for ii, cell in enumerate(r) if ii not in missed]
                if len(row) != len(headers):
                    fails.append(str(i))
                    continue
                data.append(row)
                done += 1
                if done % 999 == 0:
                    callback(done * 0.6 / total, ("Extract records: {}".format(len(res)) + (f"{len(fails)} failure({sheetname}), line: %s..." % (",".join(fails[:3])) if fails else "")))
            res.append(pd.DataFrame(np.array(data), columns=headers))

        callback(0.6, ("Extract records: {}. ".format(done) + (
            f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
        return res


def trans_datatime(s):
    try:
        return datetime_parse(s.strip()).strftime("%Y-%m-%dT%H:%M:%S")
    except Exception as e:
        pass


def trans_bool(s):
    if re.match(r"(true|yes|是)$", str(s).strip(), flags=re.IGNORECASE): return ["yes", "是"]
    if re.match(r"(false|no|否)$", str(s).strip(), flags=re.IGNORECASE): return ["no", "否"]


def column_data_type(arr):
    uni = len(set([a for a in arr if a is not None]))
    counts = {"int": 0, "float": 0, "text": 0, "datetime": 0, "bool": 0}
    trans = {t: f for f, t in [(int, "int"), (float, "float"), (trans_datatime, "datetime"), (trans_bool, "bool"), (str, "text")]}
    for a in arr:
        if a is None: continue
        if re.match(r"[+-]?[0-9]+(\.0+)?$", str(a).replace("%%", "")):
            counts["int"] += 1
        elif re.match(r"[+-]?[0-9.]+$", str(a).replace("%%", "")):
            counts["float"] += 1
        elif re.match(r"(true|false|yes|no|是|否)$", str(a), flags=re.IGNORECASE):
            counts["bool"] += 1
        elif trans_datatime(str(a)):
            counts["datetime"] += 1
        else: counts["text"] += 1
    counts = sorted(counts.items(), key=lambda x: x[1] * -1)
    ty = counts[0][0]
    for i in range(len(arr)):
        if arr[i] is None: continue
        try:
            arr[i] = trans[ty](str(arr[i]))
        except Exception as e:
            arr[i] = None
    if ty == "text":
        if len(arr) > 128 and uni / len(arr) < 0.1:
            ty = "keyword"
    return arr, ty
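# A quick sketch of the type inference above on a toy column: digit-only strings vote
# for "int", the winning type's caster is applied in place, and values that fail to
# cast become None (the sample values are illustrative).
col, ty = column_data_type(["1", "2", "3", None, "x"])
# votes: int=3, text=1 -> ty == "int"; int("x") fails, so "x" becomes None
# col == [1, 2, 3, None, None]; a "text" column is promoted to "keyword" only when it
# has more than 128 values with low uniqueness.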
def chunk(filename, binary=None, callback=None, **kwargs):
    dfs = []
    if re.search(r"\.xlsx?$", filename, re.IGNORECASE):
        callback(0.1, "Start to parse.")
        excel_parser = Excel()
        dfs = excel_parser(filename, binary, callback)
    elif re.search(r"\.(txt|csv)$", filename, re.IGNORECASE):
        callback(0.1, "Start to parse.")
        txt = ""
        if binary:
            txt = binary.decode("utf-8")
        else:
            with open(filename, "r") as f:
                while True:
                    l = f.readline()
                    if not l: break
                    txt += l
        lines = txt.split("\n")
        fails = []
        headers = lines[0].split(kwargs.get("delimiter", "\t"))
        rows = []
        for i, line in enumerate(lines[1:]):
            row = [l for l in line.split(kwargs.get("delimiter", "\t"))]
            if len(row) != len(headers):
                fails.append(str(i))
                continue
            rows.append(row)
            if len(rows) % 999 == 0:
                callback(len(rows) * 0.6 / len(lines), ("Extract records: {}".format(len(rows)) + (
                    f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
        callback(0.6, ("Extract records: {}".format(len(rows)) + (
            f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
        dfs = [pd.DataFrame(np.array(rows), columns=headers)]
    else: raise NotImplementedError("file type not supported yet(excel, text, csv supported)")

    res = []
    PY = Pinyin()
    fieds_map = {"text": "_tks", "int": "_int", "keyword": "_kwd", "float": "_flt", "datetime": "_dt", "bool": "_kwd"}
    for df in dfs:
        for n in ["id", "_id", "index", "idx"]:
            if n in df.columns: del df[n]
        clmns = df.columns.values
        txts = list(copy.deepcopy(clmns))
        py_clmns = [PY.get_pinyins(n)[0].replace("-", "_") for n in clmns]
        clmn_tys = []
        for j in range(len(clmns)):
            cln, ty = column_data_type(df[clmns[j]])
            clmn_tys.append(ty)
            df[clmns[j]] = cln
            if ty == "text": txts.extend([str(c) for c in cln if c])
        clmns_map = [(py_clmns[j] + fieds_map[clmn_tys[j]], clmns[j]) for j in range(len(clmns))]
        # TODO: set this column map to KB parser configuration

        eng = is_english(txts)
        for ii, row in df.iterrows():
            d = {}
            row_txt = []
            for j in range(len(clmns)):
                if row[clmns[j]] is None: continue
                fld = clmns_map[j][0]
                d[fld] = row[clmns[j]] if clmn_tys[j] != "text" else huqie.qie(row[clmns[j]])
                row_txt.append("{}:{}".format(clmns[j], row[clmns[j]]))
            if not row_txt: continue
            tokenize(d, "; ".join(row_txt), eng)
            print(d)
            res.append(d)
    callback(0.6, "")
    return res


if __name__ == "__main__":
    import sys

    def dummy(a, b):
        pass
    chunk(sys.argv[1], callback=dummy)
ps = int(req.get("size", 1000))
src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id",
                         "image_id", "doc_id", "q_512_vec", "q_768_vec",
                         "q_1024_vec", "q_1536_vec", "available_int"])
                         "q_1024_vec", "q_1536_vec", "available_int", "content_with_weight"])

s = s.query(bqry)[pg * ps:(pg + 1) * ps]
s = s.highlight("content_ltks")

sres.field[i].get("q_%d_vec" % len(sres.query_vector), "\t".join(["0"] * len(sres.query_vector)))) for i in sres.ids]
if not ins_embd:
    return [], [], []
ins_tw = [huqie.qie(sres.field[i][cfield]).split(" ")
ins_tw = [sres.field[i][cfield].split(" ")
          for i in sres.ids]
sim, tksim, vtsim = self.qryr.hybrid_similarity(sres.query_vector,
                                                ins_embd,

d = {
    "chunk_id": id,
    "content_ltks": sres.field[id]["content_ltks"],
    "content_with_weight": sres.field[id]["content_with_weight"],
    "doc_id": sres.field[id]["doc_id"],
    "docnm_kwd": dnm,
    "kb_id": sres.field[id]["kb_id"],
import copy
import random

from .pdf_parser import HuParser as PdfParser
from .docx_parser import HuDocxParser as DocxParser

]
]


def random_choices(arr, k):
    k = min(len(arr), k)
    return random.choices(arr, k=k)


def bullets_category(sections):
    global BULLET_PATTERN
# -*- coding: utf-8 -*-
import os
import random
from functools import partial
import fitz
import requests
import xgboost as xgb
from io import BytesIO
import torch
import logging
from PIL import Image
import numpy as np
from api.db import ParserType
from rag.nlp import huqie
from collections import Counter
from copy import deepcopy
from rag.cv.table_recognize import TableTransformer
from rag.cv.ppdetection import PPDet
from huggingface_hub import hf_hub_download

logging.getLogger("pdfminer").setLevel(logging.WARNING)

from paddleocr import PaddleOCR
logging.getLogger("ppocr").setLevel(logging.ERROR)

        self.ocr = PaddleOCR(use_angle_cls=False, lang="ch")
        self.layouter = PPDet("/data/newpeak/medical-gpt/res/ppdet")
        self.tbl_det = PPDet("/data/newpeak/medical-gpt/res/ppdet.tbl")
        if not hasattr(self, "model_speciess"):
            self.model_speciess = ParserType.GENERAL.value
        self.layouter = partial(self.__remote_call, self.model_speciess)
        self.tbl_det = partial(self.__remote_call, "table_component")
        self.updown_cnt_mdl = xgb.Booster()
        if torch.cuda.is_available():

    """
    def __remote_call(self, species, images, thr=0.7):
        url = os.environ.get("INFINIFLOW_SERVER")
        if not url: raise EnvironmentError("Please set environment variable: 'INFINIFLOW_SERVER'")
        token = os.environ.get("INFINIFLOW_TOKEN")
        if not token: raise EnvironmentError("Please set environment variable: 'INFINIFLOW_TOKEN'")

        def convert_image_to_bytes(PILimage):
            image = BytesIO()
            PILimage.save(image, format='png')
            image.seek(0)
            return image.getvalue()

        images = [convert_image_to_bytes(img) for img in images]

        def remote_call():
            nonlocal images, thr
            res = requests.post(url + "/v1/layout/detect/" + species, files=[("image", img) for img in images], data={"threashold": thr},
                                headers={"Authorization": token}, timeout=len(images) * 10)
            res = res.json()
            if res["retcode"] != 0: raise RuntimeError(res["retmsg"])
            return res["data"]
        for _ in range(3):
            try:
                return remote_call()
            except RuntimeError as e:
                raise e
            except Exception as e:
                logging.error("layout_predict:" + str(e))
        return remote_call()

    def __char_width(self, c):
        return (c["x1"] - c["x0"]) // len(c["text"])

        return layouts

    def __table_paddle(self, images):
        tbls = self.tbl_det([np.array(img) for img in images], thr=0.5)
        tbls = self.tbl_det(images, thr=0.5)
        res = []
        # align left&right for rows, align top&bottom for columns
        for tbl in tbls:

        assert len(self.page_images) == len(self.boxes)
        # Tag layout type
        boxes = []
        layouts = self.layouter([np.array(img) for img in self.page_images])
        layouts = self.layouter(self.page_images)
        assert len(self.page_images) == len(layouts)
        for pn, lts in enumerate(layouts):
            bxs = self.boxes[pn]

            self.__ocr_paddle(i + 1, img, chars, zoomin)

        if not self.is_english and not any([c for c in self.page_chars]) and self.boxes:
            self.is_english = re.search(r"[\na-zA-Z0-9,/¸;:'\[\]\(\)!@#$%^&*\"?<>._-]{30,}", "".join([b["text"] for b in random.choices([b for bxs in self.boxes for b in bxs], k=30)]))
            bxes = [b for bxs in self.boxes for b in bxs]
            self.is_english = re.search(r"[\na-zA-Z0-9,/¸;:'\[\]\(\)!@#$%^&*\"?<>._-]{30,}", "".join([b["text"] for b in random.choices(bxes, k=min(30, len(bxes)))]))
        logging.info("Is it English: %s", self.is_english)

while True:
    dispatch()
    time.sleep(3)
    time.sleep(1)
    update_progress()
from io import BytesIO
import pandas as pd
from rag.app import laws, paper, presentation, manual, qa
from rag.app import laws, paper, presentation, manual, qa, table, book
from api.db import LLMType, ParserType
from api.db.services.document_service import DocumentService

FACTORY = {
    ParserType.GENERAL.value: laws,
    ParserType.PAPER.value: paper,
    ParserType.BOOK.value: book,
    ParserType.PRESENTATION.value: presentation,
    ParserType.MANUAL.value: manual,
    ParserType.LAWS.value: laws,
    ParserType.QA.value: qa,
    ParserType.TABLE.value: table,
}
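# A minimal sketch of how a task row is dispatched through FACTORY, mirroring the
# executor code below (the row dict is a made-up example; MINIO fetching is omitted).
row = {"parser_id": "Table", "name": "prices.xlsx", "kb_id": "kb1",
       "location": "prices.xlsx", "from_page": 0, "to_page": 100000}
chunker = FACTORY[row["parser_id"].lower()]  # -> rag.app.table
# cks = chunker.chunk(row["name"], MINIO.get(row["kb_id"], row["location"]),
#                     row["from_page"], row["to_page"], callback=print)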
| d = {"progress_msg": msg} | d = {"progress_msg": msg} | ||||
| if prog is not None: d["progress"] = prog | if prog is not None: d["progress"] = prog | ||||
| try: | try: | ||||
| TaskService.update_by_id(task_id, d) | |||||
| TaskService.update_progress(task_id, d) | |||||
| except Exception as e: | except Exception as e: | ||||
| cron_logger.error("set_progress:({}), {}".format(task_id, str(e))) | cron_logger.error("set_progress:({}), {}".format(task_id, str(e))) | ||||
| return [] | return [] | ||||
| callback = partial(set_progress, row["id"], row["from_page"], row["to_page"]) | callback = partial(set_progress, row["id"], row["from_page"], row["to_page"]) | ||||
| chunker = FACTORY[row["parser_id"]] | |||||
| chunker = FACTORY[row["parser_id"].lower()] | |||||
| try: | try: | ||||
| cron_logger.info("Chunkking {}/{}".format(row["location"], row["name"])) | cron_logger.info("Chunkking {}/{}".format(row["location"], row["name"])) | ||||
| cks = chunker.chunk(row["name"], MINIO.get(row["kb_id"], row["location"]), row["from_page"], row["to_page"], | cks = chunker.chunk(row["name"], MINIO.get(row["kb_id"], row["location"]), row["from_page"], row["to_page"], | ||||
| MINIO.put(row["kb_id"], d["_id"], output_buffer.getvalue()) | MINIO.put(row["kb_id"], d["_id"], output_buffer.getvalue()) | ||||
| d["img_id"] = "{}-{}".format(row["kb_id"], d["_id"]) | d["img_id"] = "{}-{}".format(row["kb_id"], d["_id"]) | ||||
| del d["image"] | |||||
| docs.append(d) | docs.append(d) | ||||
| return docs | return docs | ||||
| def embedding(docs, mdl): | def embedding(docs, mdl): | ||||
| tts, cnts = [d["docnm_kwd"] for d in docs if d.get("docnm_kwd")], [d["content_with_weight"] for d in docs] | |||||
| tts, cnts = [rmSpace(d["title_tks"]) for d in docs if d.get("title_tks")], [d["content_with_weight"] for d in docs] | |||||
| tk_count = 0 | tk_count = 0 | ||||
| if len(tts) == len(cnts): | if len(tts) == len(cnts): | ||||
| tts, c = mdl.encode(tts) | tts, c = mdl.encode(tts) | ||||
| cks = build(r, cv_mdl) | cks = build(r, cv_mdl) | ||||
| if not cks: | if not cks: | ||||
| tmf.write(str(r["update_time"]) + "\n") | tmf.write(str(r["update_time"]) + "\n") | ||||
| callback(1., "No chunk! Done!") | |||||
| continue | continue | ||||
| # TODO: exception handler | # TODO: exception handler | ||||
| ## set_progress(r["did"], -1, "ERROR: ") | ## set_progress(r["did"], -1, "ERROR: ") | ||||
| except Exception as e: | except Exception as e: | ||||
| callback(-1, "Embedding error:{}".format(str(e))) | callback(-1, "Embedding error:{}".format(str(e))) | ||||
| cron_logger.error(str(e)) | cron_logger.error(str(e)) | ||||
| continue | |||||
| callback(msg="Finished embedding! Start to build index!") | callback(msg="Finished embedding! Start to build index!") | ||||
| init_kb(r) | init_kb(r) | ||||
| else: | else: | ||||
| if TaskService.do_cancel(r["id"]): | if TaskService.do_cancel(r["id"]): | ||||
| ELASTICSEARCH.deleteByQuery(Q("match", doc_id=r["doc_id"]), idxnm=search.index_name(r["tenant_id"])) | ELASTICSEARCH.deleteByQuery(Q("match", doc_id=r["doc_id"]), idxnm=search.index_name(r["tenant_id"])) | ||||
| continue | |||||
| callback(1., "Done!") | callback(1., "Done!") | ||||
| DocumentService.increment_chunk_num(r["doc_id"], r["kb_id"], tk_count, chunk_count, 0) | DocumentService.increment_chunk_num(r["doc_id"], r["kb_id"], tk_count, chunk_count, 0) | ||||
| cron_logger.info("Chunk doc({}), token({}), chunks({})".format(r["id"], tk_count, len(cks))) | cron_logger.info("Chunk doc({}), token({}), chunks({})".format(r["id"], tk_count, len(cks))) |