### What problem does this PR solve?

Wrap search app.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

tags/v0.19.1
| @@ -0,0 +1,188 @@ | |||
| # | |||
| # Copyright 2025 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| from flask import request | |||
| from flask_login import current_user, login_required | |||
| from api import settings | |||
| from api.constants import DATASET_NAME_LIMIT | |||
| from api.db import StatusEnum | |||
| from api.db.db_models import DB | |||
| from api.db.services import duplicate_name | |||
| from api.db.services.knowledgebase_service import KnowledgebaseService | |||
| from api.db.services.search_service import SearchService | |||
| from api.db.services.user_service import TenantService, UserTenantService | |||
| from api.utils import get_uuid | |||
| from api.utils.api_utils import get_data_error_result, get_json_result, not_allowed_parameters, server_error_response, validate_request | |||
@manager.route("/create", methods=["post"])  # noqa: F821
@login_required
@validate_request("name")
def create():
    """Create a new search app owned by the current user.

    Reads a JSON body containing ``name`` and, optionally, ``description``.
    The name must be a non-blank string whose UTF-8 encoding does not exceed
    ``DATASET_NAME_LIMIT`` bytes.

    Returns:
        A JSON result carrying ``{"search_id": <new id>}`` on success,
        otherwise a data-error or server-error response.
    """
    req = request.get_json()
    search_name = req["name"]
    description = req.get("description", "")

    if not isinstance(search_name, str):
        return get_data_error_result(message="Search name must be string.")
    if search_name.strip() == "":
        return get_data_error_result(message="Search name can't be empty.")
    # The limit is enforced on the UTF-8 byte length, not the character count.
    if len(search_name.encode("utf-8")) > DATASET_NAME_LIMIT:
        return get_data_error_result(message=f"Search name length is {len(search_name)} which is large than {DATASET_NAME_LIMIT}")

    # The current user's id is used as the tenant id for the new search app.
    e, _ = TenantService.get_by_id(current_user.id)
    if not e:
        return get_data_error_result(message="Authorizationd identity.")

    search_name = search_name.strip()
    # Resolve name collisions against existing *search apps*. The lookup must
    # go through SearchService; checking knowledge bases would compare against
    # the wrong table and let duplicate search names through.
    search_name = duplicate_name(SearchService.query, name=search_name, tenant_id=current_user.id, status=StatusEnum.VALID.value)

    req["id"] = get_uuid()
    req["name"] = search_name
    req["description"] = description
    req["tenant_id"] = current_user.id
    req["created_by"] = current_user.id

    with DB.atomic():
        try:
            if not SearchService.save(**req):
                return get_data_error_result()
            return get_json_result(data={"search_id": req["id"]})
        except Exception as e:
            return server_error_response(e)
@manager.route("/update", methods=["post"])  # noqa: F821
@login_required
@validate_request("search_id", "name", "search_config", "tenant_id")
@not_allowed_parameters("id", "created_by", "create_time", "update_time", "create_date", "update_date", "created_by")
def update():
    """Update the name and configuration of an existing search app.

    Reads a JSON body containing ``search_id``, ``name``, ``search_config``
    and ``tenant_id``. The supplied ``search_config`` is shallow-merged over
    the stored configuration, so keys that are not supplied keep their
    current values.

    Returns:
        A JSON result carrying the updated record as a dict on success,
        otherwise an error response describing what went wrong.
    """
    req = request.get_json()

    if not isinstance(req["name"], str):
        return get_data_error_result(message="Search name must be string.")
    if req["name"].strip() == "":
        return get_data_error_result(message="Search name can't be empty.")
    # The limit is enforced on the UTF-8 byte length, not the character count.
    if len(req["name"].encode("utf-8")) > DATASET_NAME_LIMIT:
        return get_data_error_result(message=f"Search name length is {len(req['name'])} which is large than {DATASET_NAME_LIMIT}")
    req["name"] = req["name"].strip()

    tenant_id = req["tenant_id"]
    e, _ = TenantService.get_by_id(tenant_id)
    if not e:
        return get_data_error_result(message="Authorizationd identity.")

    search_id = req["search_id"]
    if not SearchService.accessible4deletion(search_id, current_user.id):
        return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)

    try:
        # Check for an empty result *before* indexing: indexing first would
        # raise IndexError and turn "not found" into a generic server error.
        matches = SearchService.query(tenant_id=tenant_id, id=search_id)
        if not matches:
            return get_json_result(data=False, message=f"Cannot find search {search_id}", code=settings.RetCode.DATA_ERROR)
        search_app = matches[0]

        # A rename (ignoring case) must not collide with another valid search
        # app of the same tenant.
        if req["name"].lower() != search_app.name.lower() and len(SearchService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value)) >= 1:
            return get_data_error_result(message="Duplicated search name.")

        if "search_config" in req:
            current_config = search_app.search_config or {}
            new_config = req["search_config"]
            if not isinstance(new_config, dict):
                return get_data_error_result(message="search_config must be a JSON object")
            # Shallow merge: supplied top-level keys override the stored ones.
            req["search_config"] = {**current_config, **new_config}

        # These identify the row and are not columns to be written.
        req.pop("search_id", None)
        req.pop("tenant_id", None)

        updated = SearchService.update_by_id(search_id, req)
        if not updated:
            return get_data_error_result(message="Failed to update search")

        e, updated_search = SearchService.get_by_id(search_id)
        if not e:
            return get_data_error_result(message="Failed to fetch updated search")

        return get_json_result(data=updated_search.to_dict())
    except Exception as e:
        return server_error_response(e)
@manager.route("/detail", methods=["GET"])  # noqa: F821
@login_required
def detail():
    """Return the details of one search app.

    The ``search_id`` query parameter selects the app. The request is allowed
    only when the app belongs to one of the tenants associated with the
    current user.
    """
    search_id = request.args["search_id"]
    try:
        memberships = UserTenantService.query(user_id=current_user.id)
        # Stops at the first tenant that owns the requested search app.
        permitted = any(SearchService.query(tenant_id=membership.tenant_id, id=search_id) for membership in memberships)
        if not permitted:
            return get_json_result(data=False, message="Has no permission for this operation.", code=settings.RetCode.OPERATING_ERROR)

        found = SearchService.get_detail(search_id)
        if found:
            return get_json_result(data=found)
        return get_data_error_result(message="Can't find this Search App!")
    except Exception as e:
        return server_error_response(e)
@manager.route("/list", methods=["POST"])  # noqa: F821
@login_required
def list_search_app():
    """List search apps visible to the current user.

    Query parameters: ``keywords``, ``page``, ``page_size``, ``orderby``
    (defaults to ``create_time``) and ``desc`` (descending unless the value is
    ``false``, case-insensitively). The JSON body may carry ``owner_ids`` to
    restrict the listing to specific tenants.

    Returns:
        A JSON result of the form ``{"search_apps": [...], "total": <int>}``.
    """
    args = request.args
    keywords = args.get("keywords", "")
    page_number = int(args.get("page", 0))
    items_per_page = int(args.get("page_size", 0))
    orderby = args.get("orderby", "create_time")
    desc = args.get("desc", "true").lower() != "false"

    req = request.get_json()
    owner_ids = req.get("owner_ids", [])
    try:
        if owner_ids:
            # Explicit owners: fetch without pagination, filter by owner, then
            # paginate in memory so ``total`` reflects the filtered set.
            tenants = owner_ids
            search_apps, total = SearchService.get_by_tenant_ids(tenants, current_user.id, 0, 0, orderby, desc, keywords)
            search_apps = [app for app in search_apps if app["tenant_id"] in tenants]
            total = len(search_apps)
            if page_number and items_per_page:
                start = (page_number - 1) * items_per_page
                search_apps = search_apps[start : start + items_per_page]
        else:
            # No explicit owners: use the tenants the user has joined and let
            # the service paginate.
            joined = TenantService.get_joined_tenants_by_user_id(current_user.id)
            tenants = [m["tenant_id"] for m in joined]
            search_apps, total = SearchService.get_by_tenant_ids(tenants, current_user.id, page_number, items_per_page, orderby, desc, keywords)
        return get_json_result(data={"search_apps": search_apps, "total": total})
    except Exception as e:
        return server_error_response(e)
@manager.route("/rm", methods=["post"])  # noqa: F821
@login_required
@validate_request("search_id")
def rm():
    """Delete the search app named by ``search_id`` in the JSON body.

    The request is rejected unless ``SearchService.accessible4deletion``
    accepts the current user for that search app.
    """
    payload = request.get_json()
    search_id = payload["search_id"]

    allowed = SearchService.accessible4deletion(search_id, current_user.id)
    if not allowed:
        return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)

    try:
        if SearchService.delete_by_id(search_id):
            return get_json_result(data=True)
        return get_data_error_result(message=f"Failed to delete search App {search_id}")
    except Exception as e:
        return server_error_response(e)
| @@ -13,16 +13,16 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| import hashlib | |||
| import inspect | |||
| import logging | |||
| import operator | |||
| import os | |||
| import sys | |||
| import typing | |||
| import time | |||
| import typing | |||
| from enum import Enum | |||
| from functools import wraps | |||
| import hashlib | |||
| from flask_login import UserMixin | |||
| from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer | |||
| @@ -264,14 +264,15 @@ class BaseDataBase: | |||
| def with_retry(max_retries=3, retry_delay=1.0): | |||
| """Decorator: Add retry mechanism to database operations | |||
| Args: | |||
| max_retries (int): maximum number of retries | |||
| retry_delay (float): initial retry delay (seconds), will increase exponentially | |||
| Returns: | |||
| decorated function | |||
| """ | |||
| def decorator(func): | |||
| @wraps(func) | |||
| def wrapper(*args, **kwargs): | |||
| @@ -284,26 +285,28 @@ def with_retry(max_retries=3, retry_delay=1.0): | |||
| # get self and method name for logging | |||
| self_obj = args[0] if args else None | |||
| func_name = func.__name__ | |||
| lock_name = getattr(self_obj, 'lock_name', 'unknown') if self_obj else 'unknown' | |||
| lock_name = getattr(self_obj, "lock_name", "unknown") if self_obj else "unknown" | |||
| if retry < max_retries - 1: | |||
| current_delay = retry_delay * (2 ** retry) | |||
| logging.warning(f"{func_name} {lock_name} failed: {str(e)}, retrying ({retry+1}/{max_retries})") | |||
| current_delay = retry_delay * (2**retry) | |||
| logging.warning(f"{func_name} {lock_name} failed: {str(e)}, retrying ({retry + 1}/{max_retries})") | |||
| time.sleep(current_delay) | |||
| else: | |||
| logging.error(f"{func_name} {lock_name} failed after all attempts: {str(e)}") | |||
| if last_exception: | |||
| raise last_exception | |||
| return False | |||
| return wrapper | |||
| return decorator | |||
| class PostgresDatabaseLock: | |||
| def __init__(self, lock_name, timeout=10, db=None): | |||
| self.lock_name = lock_name | |||
| self.lock_id = int(hashlib.md5(lock_name.encode()).hexdigest(), 16) % (2**31-1) | |||
| self.lock_id = int(hashlib.md5(lock_name.encode()).hexdigest(), 16) % (2**31 - 1) | |||
| self.timeout = int(timeout) | |||
| self.db = db if db else DB | |||
| @@ -542,7 +545,7 @@ class LLM(DataBaseModel): | |||
| max_tokens = IntegerField(default=0) | |||
| tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, Chat, 32k...", index=True) | |||
| is_tools = BooleanField(null=False, help_text="support tools", default=False) | |||
| is_tools = BooleanField(null=False, help_text="support tools", default=False) | |||
| status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True) | |||
| def __str__(self): | |||
| @@ -796,6 +799,50 @@ class UserCanvasVersion(DataBaseModel): | |||
| db_table = "user_canvas_version" | |||
def _default_search_config():
    """Build a fresh default ``search_config`` dict for a new ``Search`` row.

    A factory is used instead of a dict literal so that model instances never
    share (and accidentally mutate) the same default object.
    """
    return {
        "kb_ids": [],
        "doc_ids": [],
        "similarity_threshold": 0.0,
        "vector_similarity_weight": 0.3,
        "use_kg": False,
        # rerank settings
        "rerank_id": "",
        "top_k": 1024,
        # chat settings
        "summary": False,
        "chat_id": "",
        "llm_setting": {
            "temperature": 0.1,
            "top_p": 0.3,
            "frequency_penalty": 0.7,
            "presence_penalty": 0.4,
        },
        # NOTE(review): this key looks like two names fused together
        # ("chat_setting" + "cross_languages"); kept as-is because stored rows
        # and clients may already rely on it. Confirm before renaming.
        "chat_settingcross_languages": [],
        "highlight": False,
        "keyword": False,
        "web_search": False,
        "related_search": False,
        "query_mindmap": False,
    }


class Search(DataBaseModel):
    """A saved search app stored in the ``search`` table."""

    id = CharField(max_length=32, primary_key=True)
    avatar = TextField(null=True, help_text="avatar base64 string")
    tenant_id = CharField(max_length=32, null=False, index=True)
    name = CharField(max_length=128, null=False, help_text="Search name", index=True)
    description = TextField(null=True, help_text="KB description")
    created_by = CharField(max_length=32, null=False, index=True)
    # Callable default: evaluated per instance, so every row starts from its
    # own independent config dict.
    search_config = JSONField(null=False, default=_default_search_config)
    status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)

    def __str__(self):
        return self.name

    class Meta:
        db_table = "search"
| def migrate_db(): | |||
| migrator = DatabaseMigrator[settings.DATABASE_TYPE.upper()].value(DB) | |||
| try: | |||
| @@ -159,6 +159,7 @@ BAD_CITATION_PATTERNS = [ | |||
| re.compile(r"ref\s*(\d+)", flags=re.IGNORECASE), # ref12、REF 12 | |||
| ] | |||
| def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set): | |||
| max_index = len(kbinfos["chunks"]) | |||
| @@ -555,7 +556,7 @@ def tts(tts_mdl, text): | |||
| return binascii.hexlify(bin).decode("utf-8") | |||
| def ask(question, kb_ids, tenant_id): | |||
| def ask(question, kb_ids, tenant_id, chat_llm_name=None): | |||
| kbs = KnowledgebaseService.get_by_ids(kb_ids) | |||
| embedding_list = list(set([kb.embd_id for kb in kbs])) | |||
| @@ -563,7 +564,7 @@ def ask(question, kb_ids, tenant_id): | |||
| retriever = settings.retrievaler if not is_knowledge_graph else settings.kg_retrievaler | |||
| embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embedding_list[0]) | |||
| chat_mdl = LLMBundle(tenant_id, LLMType.CHAT) | |||
| chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, chat_llm_name) | |||
| max_tokens = chat_mdl.max_length | |||
| tenant_ids = list(set([kb.tenant_id for kb in kbs])) | |||
| kbinfos = retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, 1, 12, 0.1, 0.3, aggs=False, rank_feature=label_question(question, kbs)) | |||
| @@ -0,0 +1,110 @@ | |||
| # | |||
| # Copyright 2024 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| from datetime import datetime | |||
| from peewee import fn | |||
| from api.db import StatusEnum | |||
| from api.db.db_models import DB, Search, User | |||
| from api.db.services.common_service import CommonService | |||
| from api.utils import current_timestamp, datetime_format | |||
class SearchService(CommonService):
    """Data-access helpers for the ``Search`` model."""

    model = Search

    @classmethod
    def save(cls, **kwargs):
        """Insert a new ``Search`` row built from ``kwargs``.

        The create/update timestamp and date fields are filled in here and
        overwrite any values passed by the caller.

        Returns:
            The created model instance.
        """
        # Take one reading so the create and update stamps agree on a new row.
        now_ts = current_timestamp()
        now_date = datetime_format(datetime.now())
        kwargs["create_time"] = now_ts
        kwargs["create_date"] = now_date
        kwargs["update_time"] = now_ts
        kwargs["update_date"] = now_date
        return cls.model.create(**kwargs)

    @classmethod
    @DB.connection_context()
    def accessible4deletion(cls, search_id, user_id) -> bool:
        """Tell whether ``user_id`` created the valid search ``search_id``.

        Returns:
            ``True`` when a row with that id, creator and valid status exists.
        """
        search = (
            cls.model.select(cls.model.id)
            .where(
                cls.model.id == search_id,
                cls.model.created_by == user_id,
                cls.model.status == StatusEnum.VALID.value,
            )
            .first()
        )
        return search is not None

    @classmethod
    @DB.connection_context()
    def get_detail(cls, search_id):
        """Fetch one valid search app joined with its owning user.

        Returns:
            The row converted with ``to_dict()``, or ``None`` when no valid
            search with a valid owning user matches ``search_id``.
        """
        fields = [
            cls.model.id,
            cls.model.avatar,
            cls.model.tenant_id,
            cls.model.name,
            cls.model.description,
            cls.model.created_by,
            cls.model.search_config,
            cls.model.update_time,
            User.nickname,
            User.avatar.alias("tenant_avatar"),
        ]
        row = (
            cls.model.select(*fields)
            .join(User, on=((User.id == cls.model.tenant_id) & (User.status == StatusEnum.VALID.value)))
            .where((cls.model.id == search_id) & (cls.model.status == StatusEnum.VALID.value))
            .first()
        )
        # ``first()`` yields None when nothing matches; report that as None
        # rather than failing on ``None.to_dict()``.
        if row is None:
            return None
        return row.to_dict()

    @classmethod
    @DB.connection_context()
    def get_by_tenant_ids(cls, joined_tenant_ids, user_id, page_number, items_per_page, orderby, desc, keywords):
        """List valid search apps owned by the given tenants or by the user.

        Args:
            joined_tenant_ids: Tenant ids whose search apps are included.
            user_id: Search apps whose tenant id equals this value are also
                included.
            page_number: Page to return; pagination is applied only when both
                this and ``items_per_page`` are truthy.
            items_per_page: Page size.
            orderby: Name of the model field to sort by.
            desc: Sort descending when truthy, ascending otherwise.
            keywords: Optional case-insensitive substring filter on the name.

        Returns:
            A ``(rows, count)`` tuple: the rows as dicts for the requested
            page, and the number of matches before pagination.
        """
        fields = [
            cls.model.id,
            cls.model.avatar,
            cls.model.tenant_id,
            cls.model.name,
            cls.model.description,
            cls.model.created_by,
            cls.model.status,
            cls.model.update_time,
            cls.model.create_time,
            User.nickname,
            User.avatar.alias("tenant_avatar"),
        ]
        query = (
            cls.model.select(*fields)
            .join(User, on=(cls.model.tenant_id == User.id))
            .where(((cls.model.tenant_id.in_(joined_tenant_ids)) | (cls.model.tenant_id == user_id)) & (cls.model.status == StatusEnum.VALID.value))
        )

        if keywords:
            query = query.where(fn.LOWER(cls.model.name).contains(keywords.lower()))

        if desc:
            query = query.order_by(cls.model.getter_by(orderby).desc())
        else:
            query = query.order_by(cls.model.getter_by(orderby).asc())

        # Count before paginating so the total covers every match.
        count = query.count()

        if page_number and items_per_page:
            query = query.paginate(page_number, items_per_page)

        return list(query.dicts()), count