### What problem does this PR solve?

Add advanced document filter.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

(tag: v0.20.0)
| @@ -45,6 +45,7 @@ from rag.utils.storage_factory import STORAGE_IMPL | |||
| from api.db.services.canvas_service import UserCanvasService | |||
| from agent.canvas import Canvas | |||
| from functools import partial | |||
| from pathlib import Path | |||
| @manager.route('/new_token', methods=['POST']) # noqa: F821 | |||
| @@ -439,7 +440,8 @@ def upload(): | |||
| "name": filename, | |||
| "location": location, | |||
| "size": len(blob), | |||
| "thumbnail": thumbnail(filename, blob) | |||
| "thumbnail": thumbnail(filename, blob), | |||
| "suffix": Path(filename).suffix.lstrip("."), | |||
| } | |||
| form_data = request.form | |||
| @@ -17,6 +17,7 @@ import json | |||
| import os.path | |||
| import pathlib | |||
| import re | |||
| from pathlib import Path | |||
| import flask | |||
| from flask import request | |||
| @@ -125,6 +126,7 @@ def web_crawl(): | |||
| "location": location, | |||
| "size": len(blob), | |||
| "thumbnail": thumbnail(filename, blob), | |||
| "suffix": Path(filename).suffix.lstrip("."), | |||
| } | |||
| if doc["type"] == FileType.VISUAL: | |||
| doc["parser_id"] = ParserType.PICTURE.value | |||
| @@ -173,6 +175,7 @@ def create(): | |||
| "created_by": current_user.id, | |||
| "type": FileType.VIRTUAL, | |||
| "name": req["name"], | |||
| "suffix": Path(req["name"]).suffix.lstrip("."), | |||
| "location": "", | |||
| "size": 0, | |||
| } | |||
| @@ -218,8 +221,10 @@ def list_docs(): | |||
| if invalid_types: | |||
| return get_data_error_result(message=f"Invalid filter conditions: {', '.join(invalid_types)} type{'s' if len(invalid_types) > 1 else ''}") | |||
| suffix = req.get("suffix", []) | |||
| try: | |||
| docs, tol = DocumentService.get_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, keywords, run_status, types) | |||
| docs, tol = DocumentService.get_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, keywords, run_status, types, suffix) | |||
| for doc_item in docs: | |||
| if doc_item["thumbnail"] and not doc_item["thumbnail"].startswith(IMG_BASE64_PREFIX): | |||
| @@ -230,6 +235,45 @@ def list_docs(): | |||
| return server_error_response(e) | |||
@manager.route("/filter", methods=["POST"])  # noqa: F821
@login_required
def get_filter():
    """Return aggregated filter facets for one knowledgebase's documents.

    Expects a JSON body with:
        kb_id (str, required): target knowledgebase id.
        keywords (str, optional): case-insensitive substring filter on name.
        suffix (list[str], optional): file-extension filter values.
        run_status (list[str], optional): must be subset of VALID_TASK_STATUS.
        types (list[str], optional): must be subset of VALID_FILE_TYPES.

    Returns a JSON result with {"total": <match count>,
    "filter": {"suffix": {...counts...}, "run_status": {...counts...}}}.
    """
    req = request.get_json()
    kb_id = req.get("kb_id")
    if not kb_id:
        return get_json_result(data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)

    # Authorization: the current user must belong to a tenant that owns this KB.
    tenants = UserTenantService.query(user_id=current_user.id)
    for tenant in tenants:
        if KnowledgebaseService.query(tenant_id=tenant.tenant_id, id=kb_id):
            break
    else:
        return get_json_result(data=False, message="Only owner of knowledgebase authorized for this operation.", code=settings.RetCode.OPERATING_ERROR)

    keywords = req.get("keywords", "")
    suffix = req.get("suffix", [])

    run_status = req.get("run_status", [])
    if run_status:
        invalid_status = {s for s in run_status if s not in VALID_TASK_STATUS}
        if invalid_status:
            return get_data_error_result(message=f"Invalid filter run status conditions: {', '.join(invalid_status)}")

    types = req.get("types", [])
    if types:
        invalid_types = {t for t in types if t not in VALID_FILE_TYPES}
        if invalid_types:
            return get_data_error_result(message=f"Invalid filter conditions: {', '.join(invalid_types)} type{'s' if len(invalid_types) > 1 else ''}")

    try:
        # Renamed from `filter` to avoid shadowing the Python builtin.
        filter_counts, total = DocumentService.get_filter_by_kb_id(kb_id, keywords, run_status, types, suffix)
        return get_json_result(data={"total": total, "filter": filter_counts})
    except Exception as e:
        return server_error_response(e)
| @manager.route("/infos", methods=["POST"]) # noqa: F821 | |||
| @login_required | |||
| def docinfos(): | |||
| @@ -14,6 +14,8 @@ | |||
| # limitations under the License | |||
| # | |||
| from pathlib import Path | |||
| from api.db.services.file2document_service import File2DocumentService | |||
| from api.db.services.file_service import FileService | |||
| @@ -82,6 +84,7 @@ def convert(): | |||
| "created_by": current_user.id, | |||
| "type": file.type, | |||
| "name": file.name, | |||
| "suffix": Path(file.name).suffix.lstrip("."), | |||
| "location": file.location, | |||
| "size": file.size | |||
| }) | |||
| @@ -634,6 +634,7 @@ class Document(DataBaseModel): | |||
| process_begin_at = DateTimeField(null=True, index=True) | |||
| process_duration = FloatField(default=0) | |||
| meta_fields = JSONField(null=True, default={}) | |||
| suffix = CharField(max_length=32, null=False, help_text="The real file extension suffix", index=True) | |||
| run = CharField(max_length=1, null=True, help_text="start to run processing or cancel.(1: run it; 2: cancel)", default="0", index=True) | |||
| status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True) | |||
| @@ -960,3 +961,7 @@ def migrate_db(): | |||
| migrate(migrator.rename_column("document", "process_duation", "process_duration")) | |||
| except Exception: | |||
| pass | |||
| try: | |||
| migrate(migrator.add_column("document", "suffix", CharField(max_length=32, null=False, default="", help_text="The real file extension suffix", index=True))) | |||
| except Exception: | |||
| pass | |||
| @@ -72,7 +72,7 @@ class DocumentService(CommonService): | |||
| @classmethod | |||
| @DB.connection_context() | |||
| def get_by_kb_id(cls, kb_id, page_number, items_per_page, | |||
| orderby, desc, keywords, run_status, types): | |||
| orderby, desc, keywords, run_status, types, suffix): | |||
| if keywords: | |||
| docs = cls.model.select().where( | |||
| (cls.model.kb_id == kb_id), | |||
| @@ -85,6 +85,8 @@ class DocumentService(CommonService): | |||
| docs = docs.where(cls.model.run.in_(run_status)) | |||
| if types: | |||
| docs = docs.where(cls.model.type.in_(types)) | |||
| if suffix: | |||
| docs = docs.where(cls.model.suffix.in_(suffix)) | |||
| count = docs.count() | |||
| if desc: | |||
| @@ -98,6 +100,54 @@ class DocumentService(CommonService): | |||
| return list(docs.dicts()), count | |||
@classmethod
@DB.connection_context()
def get_filter_by_kb_id(cls, kb_id, keywords, run_status, types, suffix):
    """Aggregate facet counts for the documents of one knowledgebase.

    Applies the same optional filters as get_by_kb_id (case-insensitive
    keyword substring on the name, run status, file type, file suffix),
    then counts the matching documents grouped by suffix and by run status.

    Returns:
        tuple: (facets, total) where facets looks like
            {
                "suffix": {"ppt": 1, "docx": 2},
                "run_status": {"1": 2, "2": 2}
            }
        and "1" => RUNNING, "2" => CANCEL; total is the number of
        documents matching all filters.
    """
    # Build the filter incrementally instead of duplicating the base query
    # in keyword/no-keyword branches; chained .where() calls AND together.
    query = cls.model.select().where(cls.model.kb_id == kb_id)
    if keywords:
        query = query.where(fn.LOWER(cls.model.name).contains(keywords.lower()))
    if run_status:
        query = query.where(cls.model.run.in_(run_status))
    if types:
        query = query.where(cls.model.type.in_(types))
    if suffix:
        query = query.where(cls.model.suffix.in_(suffix))

    # Narrow the projection to the two facet columns needed for counting.
    rows = query.select(cls.model.run, cls.model.suffix)
    total = rows.count()

    suffix_counter = {}
    run_status_counter = {}
    for row in rows:
        suffix_counter[row.suffix] = suffix_counter.get(row.suffix, 0) + 1
        # run is stored as a 1-char field; normalize to str for the JSON keys.
        run_status_counter[str(row.run)] = run_status_counter.get(str(row.run), 0) + 1

    return {
        "suffix": suffix_counter,
        "run_status": run_status_counter
    }, total
| @classmethod | |||
| @DB.connection_context() | |||
| def count_by_kb_id(cls, kb_id, keywords, run_status, types): | |||
| @@ -17,6 +17,7 @@ import logging | |||
| import os | |||
| import re | |||
| from concurrent.futures import ThreadPoolExecutor | |||
| from pathlib import Path | |||
| from flask_login import current_user | |||
| from peewee import fn | |||
| @@ -446,6 +447,7 @@ class FileService(CommonService): | |||
| "created_by": user_id, | |||
| "type": filetype, | |||
| "name": filename, | |||
| "suffix": Path(filename).suffix.lstrip("."), | |||
| "location": location, | |||
| "size": len(blob), | |||
| "thumbnail": thumbnail_location, | |||