### What problem does this PR solve?

Wrap search app

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

tags/v0.19.1
| # | |||||
| # Copyright 2025 The InfiniFlow Authors. All Rights Reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # | |||||
| from flask import request | |||||
| from flask_login import current_user, login_required | |||||
| from api import settings | |||||
| from api.constants import DATASET_NAME_LIMIT | |||||
| from api.db import StatusEnum | |||||
| from api.db.db_models import DB | |||||
| from api.db.services import duplicate_name | |||||
| from api.db.services.knowledgebase_service import KnowledgebaseService | |||||
| from api.db.services.search_service import SearchService | |||||
| from api.db.services.user_service import TenantService, UserTenantService | |||||
| from api.utils import get_uuid | |||||
| from api.utils.api_utils import get_data_error_result, get_json_result, not_allowed_parameters, server_error_response, validate_request | |||||
@manager.route("/create", methods=["post"])  # noqa: F821
@login_required
@validate_request("name")
def create():
    """Create a search app owned by the current user.

    Reads a JSON body with a ``name`` and an optional ``description``
    (defaults to ``""``). The name must be a non-blank string whose UTF-8
    encoding does not exceed ``DATASET_NAME_LIMIT`` bytes.

    Returns:
        A JSON result carrying ``{"search_id": <new id>}`` on success, or a
        data/server error result otherwise.
    """
    req = request.get_json()
    search_name = req["name"]
    description = req.get("description", "")

    if not isinstance(search_name, str):
        return get_data_error_result(message="Search name must be string.")
    if search_name.strip() == "":
        return get_data_error_result(message="Search name can't be empty.")

    # The limit is enforced on the encoded size, so report that same number
    # instead of the character count.
    name_bytes = len(search_name.encode("utf-8"))
    if name_bytes > DATASET_NAME_LIMIT:
        return get_data_error_result(message=f"Search name length is {name_bytes} bytes which is larger than {DATASET_NAME_LIMIT}")

    e, _ = TenantService.get_by_id(current_user.id)
    if not e:
        return get_data_error_result(message="Tenant not found.")

    search_name = search_name.strip()
    # De-duplicate against existing *search apps*; the previous code queried
    # KnowledgebaseService, so clashes with other searches were never detected.
    search_name = duplicate_name(SearchService.query, name=search_name, tenant_id=current_user.id, status=StatusEnum.VALID.value)

    req["id"] = get_uuid()
    req["name"] = search_name
    req["description"] = description
    req["tenant_id"] = current_user.id
    req["created_by"] = current_user.id

    with DB.atomic():
        try:
            if not SearchService.save(**req):
                return get_data_error_result()
            return get_json_result(data={"search_id": req["id"]})
        except Exception as e:
            return server_error_response(e)
@manager.route("/update", methods=["post"])  # noqa: F821
@login_required
@validate_request("search_id", "name", "search_config", "tenant_id")
@not_allowed_parameters("id", "created_by", "create_time", "update_time", "create_date", "update_date")
def update():
    """Update the name and configuration of an existing search app.

    Reads a JSON body with ``search_id``, ``name``, ``search_config`` and
    ``tenant_id``. The supplied ``search_config`` is shallow-merged over the
    stored configuration, so keys that are not sent keep their current value.

    Returns:
        A JSON result carrying the updated record as a dict on success, or an
        error result describing why the update was rejected.
    """
    req = request.get_json()

    if not isinstance(req["name"], str):
        return get_data_error_result(message="Search name must be string.")
    if req["name"].strip() == "":
        return get_data_error_result(message="Search name can't be empty.")

    # The limit is enforced on the encoded size, so report that same number.
    name_bytes = len(req["name"].encode("utf-8"))
    if name_bytes > DATASET_NAME_LIMIT:
        return get_data_error_result(message=f"Search name length is {name_bytes} bytes which is larger than {DATASET_NAME_LIMIT}")
    req["name"] = req["name"].strip()

    tenant_id = req["tenant_id"]
    e, _ = TenantService.get_by_id(tenant_id)
    if not e:
        return get_data_error_result(message="Tenant not found.")

    search_id = req["search_id"]
    if not SearchService.accessible4deletion(search_id, current_user.id):
        return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)

    try:
        # Check for an empty result *before* indexing; indexing first raised
        # IndexError and made the "Cannot find search" branch unreachable.
        matches = SearchService.query(tenant_id=tenant_id, id=search_id)
        if not matches:
            return get_json_result(data=False, message=f"Cannot find search {search_id}", code=settings.RetCode.DATA_ERROR)
        search_app = matches[0]

        # A rename (case-insensitive) must not collide with another valid
        # search app of the same tenant.
        if req["name"].lower() != search_app.name.lower() and len(SearchService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value)) >= 1:
            return get_data_error_result(message="Duplicated search name.")

        if "search_config" in req:
            current_config = search_app.search_config or {}
            new_config = req["search_config"]
            if not isinstance(new_config, dict):
                return get_data_error_result(message="search_config must be a JSON object")
            # Shallow merge: incoming keys win, untouched keys are preserved.
            req["search_config"] = {**current_config, **new_config}

        # These two only identify the record; they are not columns to update.
        req.pop("search_id", None)
        req.pop("tenant_id", None)

        updated = SearchService.update_by_id(search_id, req)
        if not updated:
            return get_data_error_result(message="Failed to update search")

        e, updated_search = SearchService.get_by_id(search_id)
        if not e:
            return get_data_error_result(message="Failed to fetch updated search")

        return get_json_result(data=updated_search.to_dict())
    except Exception as e:
        return server_error_response(e)
@manager.route("/detail", methods=["GET"])  # noqa: F821
@login_required
def detail():
    """Return the detail record of one search app.

    Reads ``search_id`` from the query string. The request is allowed only if
    at least one tenant linked to the current user has a search app with that
    id; otherwise an operating-error result is returned.
    """
    search_id = request.args["search_id"]
    try:
        memberships = UserTenantService.query(user_id=current_user.id)
        # any() stops at the first tenant that has this search app.
        permitted = any(SearchService.query(tenant_id=membership.tenant_id, id=search_id) for membership in memberships)
        if not permitted:
            return get_json_result(data=False, message="Has no permission for this operation.", code=settings.RetCode.OPERATING_ERROR)

        search = SearchService.get_detail(search_id)
        if search:
            return get_json_result(data=search)
        return get_data_error_result(message="Can't find this Search App!")
    except Exception as e:
        return server_error_response(e)
@manager.route("/list", methods=["POST"])  # noqa: F821
@login_required
def list_search_app():
    """List the search apps visible to the current user.

    Query-string parameters: ``keywords`` (name filter), ``page`` and
    ``page_size`` (both default to 0, which disables pagination), ``orderby``
    (defaults to ``create_time``) and ``desc`` (anything but ``"false"`` means
    descending). The optional JSON body may carry ``owner_ids`` to narrow the
    listing to specific tenants.

    Returns:
        A JSON result with ``search_apps`` (list of dicts) and ``total``.
    """
    keywords = request.args.get("keywords", "")
    try:
        page_number = int(request.args.get("page", 0))
        items_per_page = int(request.args.get("page_size", 0))
    except ValueError:
        # Non-numeric paging values used to escape as an unhandled exception.
        return get_data_error_result(message="page and page_size must be integers.")
    orderby = request.args.get("orderby", "create_time")
    desc = request.args.get("desc", "true").lower() != "false"

    # Tolerate a missing/empty JSON body instead of failing on None.get().
    req = request.get_json() or {}
    owner_ids = req.get("owner_ids", [])
    try:
        joined = TenantService.get_joined_tenants_by_user_id(current_user.id)
        joined_ids = [m["tenant_id"] for m in joined]

        if not owner_ids:
            search_apps, total = SearchService.get_by_tenant_ids(joined_ids, current_user.id, page_number, items_per_page, orderby, desc, keywords)
        else:
            # Restrict the caller-supplied owners to the same scope the
            # unfiltered listing uses (joined tenants plus the user itself);
            # passing them through verbatim exposed other tenants' search apps.
            allowed = set(joined_ids) | {current_user.id}
            tenants = [owner for owner in owner_ids if isinstance(owner, str) and owner in allowed]

            # Fetch unpaginated, filter to the requested owners, then page in
            # memory so that ``total`` reflects the filtered set.
            search_apps, total = SearchService.get_by_tenant_ids(tenants, current_user.id, 0, 0, orderby, desc, keywords)
            search_apps = [search_app for search_app in search_apps if search_app["tenant_id"] in tenants]
            total = len(search_apps)
            if page_number and items_per_page:
                search_apps = search_apps[(page_number - 1) * items_per_page : page_number * items_per_page]
        return get_json_result(data={"search_apps": search_apps, "total": total})
    except Exception as e:
        return server_error_response(e)
@manager.route("/rm", methods=["post"])  # noqa: F821
@login_required
@validate_request("search_id")
def rm():
    """Delete a search app.

    Reads ``search_id`` from the JSON body. The deletion proceeds only when
    ``SearchService.accessible4deletion`` approves it for the current user.
    """
    search_id = request.get_json()["search_id"]

    allowed = SearchService.accessible4deletion(search_id, current_user.id)
    if not allowed:
        return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)

    try:
        deleted = SearchService.delete_by_id(search_id)
        if deleted:
            return get_json_result(data=True)
        return get_data_error_result(message=f"Failed to delete search App {search_id}")
    except Exception as e:
        return server_error_response(e)
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | # limitations under the License. | ||||
| # | # | ||||
| import hashlib | |||||
| import inspect | import inspect | ||||
| import logging | import logging | ||||
| import operator | import operator | ||||
| import os | import os | ||||
| import sys | import sys | ||||
| import typing | |||||
| import time | import time | ||||
| import typing | |||||
| from enum import Enum | from enum import Enum | ||||
| from functools import wraps | from functools import wraps | ||||
| import hashlib | |||||
| from flask_login import UserMixin | from flask_login import UserMixin | ||||
| from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer | from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer | ||||
| def with_retry(max_retries=3, retry_delay=1.0): | def with_retry(max_retries=3, retry_delay=1.0): | ||||
| """Decorator: Add retry mechanism to database operations | """Decorator: Add retry mechanism to database operations | ||||
| Args: | Args: | ||||
| max_retries (int): maximum number of retries | max_retries (int): maximum number of retries | ||||
| retry_delay (float): initial retry delay (seconds), will increase exponentially | retry_delay (float): initial retry delay (seconds), will increase exponentially | ||||
| Returns: | Returns: | ||||
| decorated function | decorated function | ||||
| """ | """ | ||||
| def decorator(func): | def decorator(func): | ||||
| @wraps(func) | @wraps(func) | ||||
| def wrapper(*args, **kwargs): | def wrapper(*args, **kwargs): | ||||
| # get self and method name for logging | # get self and method name for logging | ||||
| self_obj = args[0] if args else None | self_obj = args[0] if args else None | ||||
| func_name = func.__name__ | func_name = func.__name__ | ||||
| lock_name = getattr(self_obj, 'lock_name', 'unknown') if self_obj else 'unknown' | |||||
| lock_name = getattr(self_obj, "lock_name", "unknown") if self_obj else "unknown" | |||||
| if retry < max_retries - 1: | if retry < max_retries - 1: | ||||
| current_delay = retry_delay * (2 ** retry) | |||||
| logging.warning(f"{func_name} {lock_name} failed: {str(e)}, retrying ({retry+1}/{max_retries})") | |||||
| current_delay = retry_delay * (2**retry) | |||||
| logging.warning(f"{func_name} {lock_name} failed: {str(e)}, retrying ({retry + 1}/{max_retries})") | |||||
| time.sleep(current_delay) | time.sleep(current_delay) | ||||
| else: | else: | ||||
| logging.error(f"{func_name} {lock_name} failed after all attempts: {str(e)}") | logging.error(f"{func_name} {lock_name} failed after all attempts: {str(e)}") | ||||
| if last_exception: | if last_exception: | ||||
| raise last_exception | raise last_exception | ||||
| return False | return False | ||||
| return wrapper | return wrapper | ||||
| return decorator | return decorator | ||||
| class PostgresDatabaseLock: | class PostgresDatabaseLock: | ||||
| def __init__(self, lock_name, timeout=10, db=None): | def __init__(self, lock_name, timeout=10, db=None): | ||||
| self.lock_name = lock_name | self.lock_name = lock_name | ||||
| self.lock_id = int(hashlib.md5(lock_name.encode()).hexdigest(), 16) % (2**31-1) | |||||
| self.lock_id = int(hashlib.md5(lock_name.encode()).hexdigest(), 16) % (2**31 - 1) | |||||
| self.timeout = int(timeout) | self.timeout = int(timeout) | ||||
| self.db = db if db else DB | self.db = db if db else DB | ||||
| max_tokens = IntegerField(default=0) | max_tokens = IntegerField(default=0) | ||||
| tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, Chat, 32k...", index=True) | tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, Chat, 32k...", index=True) | ||||
| is_tools = BooleanField(null=False, help_text="support tools", default=False) | |||||
| is_tools = BooleanField(null=False, help_text="support tools", default=False) | |||||
| status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True) | status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True) | ||||
| def __str__(self): | def __str__(self): | ||||
| db_table = "user_canvas_version" | db_table = "user_canvas_version" | ||||
def _default_search_config():
    """Build a fresh default ``search_config`` dict for a new Search row.

    Returned from a function (rather than stored as a literal on the field)
    so that model instances never share one mutable dict object.
    """
    return {
        "kb_ids": [],
        "doc_ids": [],
        "similarity_threshold": 0.0,
        "vector_similarity_weight": 0.3,
        "use_kg": False,
        # rerank settings
        "rerank_id": "",
        "top_k": 1024,
        # chat settings
        "summary": False,
        "chat_id": "",
        "llm_setting": {
            "temperature": 0.1,
            "top_p": 0.3,
            "frequency_penalty": 0.7,
            "presence_penalty": 0.4,
        },
        # NOTE(review): this key looks like two names fused together
        # ("chat_setting" + "cross_languages"); kept verbatim because stored
        # rows or clients may already depend on it - confirm before renaming.
        "chat_settingcross_languages": [],
        "highlight": False,
        "keyword": False,
        "web_search": False,
        "related_search": False,
        "query_mindmap": False,
    }


class Search(DataBaseModel):
    """Persisted search app: ownership, display data and search settings."""

    id = CharField(max_length=32, primary_key=True)
    avatar = TextField(null=True, help_text="avatar base64 string")
    tenant_id = CharField(max_length=32, null=False, index=True)
    name = CharField(max_length=128, null=False, help_text="Search name", index=True)
    description = TextField(null=True, help_text="KB description")
    created_by = CharField(max_length=32, null=False, index=True)
    # Callable default: evaluated per instance instead of sharing one dict.
    search_config = JSONField(null=False, default=_default_search_config)
    status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)

    def __str__(self):
        return self.name

    class Meta:
        db_table = "search"
| def migrate_db(): | def migrate_db(): | ||||
| migrator = DatabaseMigrator[settings.DATABASE_TYPE.upper()].value(DB) | migrator = DatabaseMigrator[settings.DATABASE_TYPE.upper()].value(DB) | ||||
| try: | try: |
| re.compile(r"ref\s*(\d+)", flags=re.IGNORECASE), # ref12、REF 12 | re.compile(r"ref\s*(\d+)", flags=re.IGNORECASE), # ref12、REF 12 | ||||
| ] | ] | ||||
| def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set): | def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set): | ||||
| max_index = len(kbinfos["chunks"]) | max_index = len(kbinfos["chunks"]) | ||||
| return binascii.hexlify(bin).decode("utf-8") | return binascii.hexlify(bin).decode("utf-8") | ||||
| def ask(question, kb_ids, tenant_id): | |||||
| def ask(question, kb_ids, tenant_id, chat_llm_name=None): | |||||
| kbs = KnowledgebaseService.get_by_ids(kb_ids) | kbs = KnowledgebaseService.get_by_ids(kb_ids) | ||||
| embedding_list = list(set([kb.embd_id for kb in kbs])) | embedding_list = list(set([kb.embd_id for kb in kbs])) | ||||
| retriever = settings.retrievaler if not is_knowledge_graph else settings.kg_retrievaler | retriever = settings.retrievaler if not is_knowledge_graph else settings.kg_retrievaler | ||||
| embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embedding_list[0]) | embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embedding_list[0]) | ||||
| chat_mdl = LLMBundle(tenant_id, LLMType.CHAT) | |||||
| chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, chat_llm_name) | |||||
| max_tokens = chat_mdl.max_length | max_tokens = chat_mdl.max_length | ||||
| tenant_ids = list(set([kb.tenant_id for kb in kbs])) | tenant_ids = list(set([kb.tenant_id for kb in kbs])) | ||||
| kbinfos = retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, 1, 12, 0.1, 0.3, aggs=False, rank_feature=label_question(question, kbs)) | kbinfos = retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, 1, 12, 0.1, 0.3, aggs=False, rank_feature=label_question(question, kbs)) |
| # | |||||
| # Copyright 2024 The InfiniFlow Authors. All Rights Reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # | |||||
| from datetime import datetime | |||||
| from peewee import fn | |||||
| from api.db import StatusEnum | |||||
| from api.db.db_models import DB, Search, User | |||||
| from api.db.services.common_service import CommonService | |||||
| from api.utils import current_timestamp, datetime_format | |||||
class SearchService(CommonService):
    """Data-access helpers for the ``Search`` model."""

    model = Search

    @classmethod
    def save(cls, **kwargs):
        """Insert a new search row and return the created model instance.

        The four create/update time columns are overwritten here, so any
        values the caller passes for them are ignored.
        """
        # Read the clock once so the create and update stamps of a new row
        # are identical rather than a few microseconds apart.
        timestamp = current_timestamp()
        date = datetime_format(datetime.now())
        kwargs["create_time"] = timestamp
        kwargs["create_date"] = date
        kwargs["update_time"] = timestamp
        kwargs["update_date"] = date
        obj = cls.model.create(**kwargs)
        return obj

    @classmethod
    @DB.connection_context()
    def accessible4deletion(cls, search_id, user_id) -> bool:
        """Return True if a valid search ``search_id`` was created by ``user_id``."""
        search = (
            cls.model.select(cls.model.id)
            .where(
                cls.model.id == search_id,
                cls.model.created_by == user_id,
                cls.model.status == StatusEnum.VALID.value,
            )
            .first()
        )
        return search is not None

    @classmethod
    @DB.connection_context()
    def get_detail(cls, search_id):
        """Return the detail dict of a valid search, or ``None`` if not found.

        The row is joined with the valid ``User`` whose id equals the search's
        ``tenant_id`` to pick up ``nickname`` and the avatar (aliased as
        ``tenant_avatar``).
        """
        fields = [
            cls.model.id,
            cls.model.avatar,
            cls.model.tenant_id,
            cls.model.name,
            cls.model.description,
            cls.model.created_by,
            cls.model.search_config,
            cls.model.update_time,
            User.nickname,
            User.avatar.alias("tenant_avatar"),
        ]
        row = (
            cls.model.select(*fields)
            .join(User, on=((User.id == cls.model.tenant_id) & (User.status == StatusEnum.VALID.value)))
            .where((cls.model.id == search_id) & (cls.model.status == StatusEnum.VALID.value))
            .first()
        )
        # first() yields None when nothing matches; calling to_dict() on it
        # unconditionally raised AttributeError instead of reporting "not found".
        if row is None:
            return None
        return row.to_dict()

    @classmethod
    @DB.connection_context()
    def get_by_tenant_ids(cls, joined_tenant_ids, user_id, page_number, items_per_page, orderby, desc, keywords):
        """List valid searches owned by the given tenants or by ``user_id``.

        Args:
            joined_tenant_ids: tenant ids whose searches are included.
            user_id: searches whose ``tenant_id`` equals this id are included too.
            page_number: page to return; pagination is skipped when this or
                ``items_per_page`` is falsy.
            items_per_page: page size.
            orderby: model field name to sort by.
            desc: sort descending when truthy, ascending otherwise.
            keywords: optional case-insensitive substring filter on the name.

        Returns:
            ``(rows, count)``: the rows as a list of dicts and the total
            number of matches before pagination.
        """
        fields = [
            cls.model.id,
            cls.model.avatar,
            cls.model.tenant_id,
            cls.model.name,
            cls.model.description,
            cls.model.created_by,
            cls.model.status,
            cls.model.update_time,
            cls.model.create_time,
            User.nickname,
            User.avatar.alias("tenant_avatar"),
        ]
        query = (
            cls.model.select(*fields)
            .join(User, on=(cls.model.tenant_id == User.id))
            .where(((cls.model.tenant_id.in_(joined_tenant_ids)) | (cls.model.tenant_id == user_id)) & (cls.model.status == StatusEnum.VALID.value))
        )

        if keywords:
            query = query.where(fn.LOWER(cls.model.name).contains(keywords.lower()))
        if desc:
            query = query.order_by(cls.model.getter_by(orderby).desc())
        else:
            query = query.order_by(cls.model.getter_by(orderby).asc())

        # Count before paginating so the total covers every matching row.
        count = query.count()

        if page_number and items_per_page:
            query = query.paginate(page_number, items_per_page)

        return list(query.dicts()), count