瀏覽代碼

Feat: wrap search app (#8320)

### What problem does this PR solve?

Wrap search app

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
tags/v0.19.1
Yongteng Lei 4 月之前
父節點
當前提交
1b022116d5
沒有連結到貢獻者的電子郵件帳戶。
共有 4 個檔案被更改，包括 359 行新增和 13 行刪除
  1. 188
    0
      api/apps/search_app.py
  2. 58
    11
      api/db/db_models.py
  3. 3
    2
      api/db/services/dialog_service.py
  4. 110
    0
      api/db/services/search_service.py

+ 188
- 0
api/apps/search_app.py 查看文件

#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from flask import request
from flask_login import current_user, login_required

from api import settings
from api.constants import DATASET_NAME_LIMIT
from api.db import StatusEnum
from api.db.db_models import DB
from api.db.services import duplicate_name
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.search_service import SearchService
from api.db.services.user_service import TenantService, UserTenantService
from api.utils import get_uuid
from api.utils.api_utils import get_data_error_result, get_json_result, not_allowed_parameters, server_error_response, validate_request


@manager.route("/create", methods=["post"])  # noqa: F821
@login_required
@validate_request("name")
def create():
    """Create a search app owned by the current user.

    Reads a JSON body containing ``name`` and, optionally, ``description``.
    The name must be a non-blank string whose UTF-8 encoding does not exceed
    ``DATASET_NAME_LIMIT`` bytes. The stripped name is passed through
    ``duplicate_name`` against existing *search apps* of the current user
    before the record is stored.

    Returns:
        A JSON result carrying ``{"search_id": <new id>}`` on success, a data
        error result when validation or saving fails, or a server error
        response when saving raises.
    """
    req = request.get_json()
    search_name = req["name"]
    description = req.get("description", "")
    if not isinstance(search_name, str):
        return get_data_error_result(message="Search name must be string.")
    if search_name.strip() == "":
        return get_data_error_result(message="Search name can't be empty.")
    # The limit is enforced on the encoded size, so report that same number.
    name_byte_length = len(search_name.encode("utf-8"))
    if name_byte_length > DATASET_NAME_LIMIT:
        return get_data_error_result(message=f"Search name length is {name_byte_length} which is large than {DATASET_NAME_LIMIT}")
    tenant_found, _ = TenantService.get_by_id(current_user.id)
    if not tenant_found:
        return get_data_error_result(message="Authorizationd identity.")

    search_name = search_name.strip()
    # De-duplicate against search apps (not knowledge bases): the record being
    # created lives in the Search table, so that is where clashes matter.
    search_name = duplicate_name(SearchService.query, name=search_name, tenant_id=current_user.id, status=StatusEnum.VALID.value)

    req["id"] = get_uuid()
    req["name"] = search_name
    req["description"] = description
    req["tenant_id"] = current_user.id
    req["created_by"] = current_user.id
    # NOTE(review): every other key of the request body is forwarded to
    # SearchService.save unchanged — confirm that is intended.
    with DB.atomic():
        try:
            if not SearchService.save(**req):
                return get_data_error_result()
            return get_json_result(data={"search_id": req["id"]})
        except Exception as e:
            return server_error_response(e)


@manager.route("/update", methods=["post"])  # noqa: F821
@login_required
@validate_request("search_id", "name", "search_config", "tenant_id")
@not_allowed_parameters("id", "created_by", "create_time", "update_time", "create_date", "update_date")
def update():
    """Update the name and configuration of an existing search app.

    Reads a JSON body with ``search_id``, ``name``, ``search_config`` and
    ``tenant_id``. The name is validated like in ``create``; the caller must
    pass ``SearchService.accessible4deletion`` for the target record. The
    submitted ``search_config`` is shallow-merged over the stored one
    (submitted keys win) before the record is updated.

    Returns:
        A JSON result with the refreshed record as a dict on success; a data
        error / JSON error result when validation, authorization, lookup or
        the update fails; a server error response on unexpected exceptions.
    """
    req = request.get_json()
    if not isinstance(req["name"], str):
        return get_data_error_result(message="Search name must be string.")
    if req["name"].strip() == "":
        return get_data_error_result(message="Search name can't be empty.")
    # The limit is enforced on the encoded size, so report that same number.
    name_byte_length = len(req["name"].encode("utf-8"))
    if name_byte_length > DATASET_NAME_LIMIT:
        return get_data_error_result(message=f"Search name length is {name_byte_length} which is large than {DATASET_NAME_LIMIT}")
    req["name"] = req["name"].strip()
    tenant_id = req["tenant_id"]
    tenant_found, _ = TenantService.get_by_id(tenant_id)
    if not tenant_found:
        return get_data_error_result(message="Authorizationd identity.")

    search_id = req["search_id"]
    if not SearchService.accessible4deletion(search_id, current_user.id):
        return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)

    try:
        # Check for an empty result before indexing; indexing first would raise
        # IndexError and hide the intended "Cannot find search" response.
        matches = SearchService.query(tenant_id=tenant_id, id=search_id)
        if not matches:
            return get_json_result(data=False, message=f"Cannot find search {search_id}", code=settings.RetCode.DATA_ERROR)
        search_app = matches[0]

        # Renaming is only a conflict when the name really changes
        # (case-insensitively) and another valid record already uses it.
        if req["name"].lower() != search_app.name.lower() and len(SearchService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value)) >= 1:
            return get_data_error_result(message="Duplicated search name.")

        if "search_config" in req:
            current_config = search_app.search_config or {}
            new_config = req["search_config"]

            if not isinstance(new_config, dict):
                return get_data_error_result(message="search_config must be a JSON object")

            # Shallow merge: top-level keys from the request replace stored ones.
            req["search_config"] = {**current_config, **new_config}

        # These two keys identify the record; they are not fields to update.
        req.pop("search_id", None)
        req.pop("tenant_id", None)

        updated = SearchService.update_by_id(search_id, req)
        if not updated:
            return get_data_error_result(message="Failed to update search")

        found, updated_search = SearchService.get_by_id(search_id)
        if not found:
            return get_data_error_result(message="Failed to fetch updated search")

        return get_json_result(data=updated_search.to_dict())

    except Exception as e:
        return server_error_response(e)


@manager.route("/detail", methods=["GET"])  # noqa: F821
@login_required
def detail():
    """Return the detail record of one search app.

    The ``search_id`` query parameter selects the record. Access is granted
    when the record belongs to any tenant returned by
    ``UserTenantService.query`` for the current user.
    """
    search_id = request.args["search_id"]
    try:
        memberships = UserTenantService.query(user_id=current_user.id)
        # Stops at the first tenant that owns the record.
        has_access = any(SearchService.query(tenant_id=membership.tenant_id, id=search_id) for membership in memberships)
        if not has_access:
            return get_json_result(data=False, message="Has no permission for this operation.", code=settings.RetCode.OPERATING_ERROR)

        record = SearchService.get_detail(search_id)
        if record:
            return get_json_result(data=record)
        return get_data_error_result(message="Can't find this Search App!")
    except Exception as err:
        return server_error_response(err)


@manager.route("/list", methods=["POST"])  # noqa: F821
@login_required
def list_search_app():
    """List search apps visible to the current user.

    Query-string parameters: ``keywords``, ``page``, ``page_size``,
    ``orderby`` (default ``create_time``) and ``desc`` (anything other than
    ``"false"``, case-insensitively, means descending). The JSON body may
    carry ``owner_ids``; when it does, results are restricted to those tenant
    ids and paginated here instead of in the service.
    """
    args = request.args
    keywords = args.get("keywords", "")
    page_number = int(args.get("page", 0))
    items_per_page = int(args.get("page_size", 0))
    orderby = args.get("orderby", "create_time")
    desc = args.get("desc", "true").lower() != "false"

    body = request.get_json()
    owner_ids = body.get("owner_ids", [])
    try:
        if owner_ids:
            tenants = owner_ids
            # Fetch unpaginated (0, 0), filter by owner, then slice the page.
            search_apps, total = SearchService.get_by_tenant_ids(tenants, current_user.id, 0, 0, orderby, desc, keywords)
            search_apps = [app for app in search_apps if app["tenant_id"] in tenants]
            total = len(search_apps)
            if page_number and items_per_page:
                start = (page_number - 1) * items_per_page
                search_apps = search_apps[start : start + items_per_page]
        else:
            joined = TenantService.get_joined_tenants_by_user_id(current_user.id)
            tenants = [entry["tenant_id"] for entry in joined]
            search_apps, total = SearchService.get_by_tenant_ids(tenants, current_user.id, page_number, items_per_page, orderby, desc, keywords)
        return get_json_result(data={"search_apps": search_apps, "total": total})
    except Exception as err:
        return server_error_response(err)


@manager.route("/rm", methods=["post"])  # noqa: F821
@login_required
@validate_request("search_id")
def rm():
    """Delete the search app named by ``search_id`` in the JSON body.

    The caller must pass ``SearchService.accessible4deletion`` for the record;
    otherwise an authentication-error result is returned.
    """
    payload = request.get_json()
    search_id = payload["search_id"]
    allowed = SearchService.accessible4deletion(search_id, current_user.id)
    if not allowed:
        return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)

    try:
        deleted = SearchService.delete_by_id(search_id)
        if deleted:
            return get_json_result(data=True)
        return get_data_error_result(message=f"Failed to delete search App {search_id}")
    except Exception as err:
        return server_error_response(err)

+ 58
- 11
api/db/db_models.py 查看文件

# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import hashlib
import inspect import inspect
import logging import logging
import operator import operator
import os import os
import sys import sys
import typing
import time import time
import typing
from enum import Enum from enum import Enum
from functools import wraps from functools import wraps
import hashlib


from flask_login import UserMixin from flask_login import UserMixin
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer


def with_retry(max_retries=3, retry_delay=1.0): def with_retry(max_retries=3, retry_delay=1.0):
"""Decorator: Add retry mechanism to database operations """Decorator: Add retry mechanism to database operations
Args: Args:
max_retries (int): maximum number of retries max_retries (int): maximum number of retries
retry_delay (float): initial retry delay (seconds), will increase exponentially retry_delay (float): initial retry delay (seconds), will increase exponentially
Returns: Returns:
decorated function decorated function
""" """

def decorator(func): def decorator(func):
@wraps(func) @wraps(func)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
# get self and method name for logging # get self and method name for logging
self_obj = args[0] if args else None self_obj = args[0] if args else None
func_name = func.__name__ func_name = func.__name__
lock_name = getattr(self_obj, 'lock_name', 'unknown') if self_obj else 'unknown'
lock_name = getattr(self_obj, "lock_name", "unknown") if self_obj else "unknown"
if retry < max_retries - 1: if retry < max_retries - 1:
current_delay = retry_delay * (2 ** retry)
logging.warning(f"{func_name} {lock_name} failed: {str(e)}, retrying ({retry+1}/{max_retries})")
current_delay = retry_delay * (2**retry)
logging.warning(f"{func_name} {lock_name} failed: {str(e)}, retrying ({retry + 1}/{max_retries})")
time.sleep(current_delay) time.sleep(current_delay)
else: else:
logging.error(f"{func_name} {lock_name} failed after all attempts: {str(e)}") logging.error(f"{func_name} {lock_name} failed after all attempts: {str(e)}")
if last_exception: if last_exception:
raise last_exception raise last_exception
return False return False

return wrapper return wrapper

return decorator return decorator




class PostgresDatabaseLock: class PostgresDatabaseLock:
def __init__(self, lock_name, timeout=10, db=None): def __init__(self, lock_name, timeout=10, db=None):
self.lock_name = lock_name self.lock_name = lock_name
self.lock_id = int(hashlib.md5(lock_name.encode()).hexdigest(), 16) % (2**31-1)
self.lock_id = int(hashlib.md5(lock_name.encode()).hexdigest(), 16) % (2**31 - 1)
self.timeout = int(timeout) self.timeout = int(timeout)
self.db = db if db else DB self.db = db if db else DB


max_tokens = IntegerField(default=0) max_tokens = IntegerField(default=0)


tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, Chat, 32k...", index=True) tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, Chat, 32k...", index=True)
is_tools = BooleanField(null=False, help_text="support tools", default=False)
is_tools = BooleanField(null=False, help_text="support tools", default=False)
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True) status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)


def __str__(self): def __str__(self):
db_table = "user_canvas_version" db_table = "user_canvas_version"




class Search(DataBaseModel):
    """Persisted definition of a search app, stored in the ``search`` table."""

    id = CharField(max_length=32, primary_key=True)
    avatar = TextField(null=True, help_text="avatar base64 string")
    tenant_id = CharField(max_length=32, null=False, index=True)
    name = CharField(max_length=128, null=False, help_text="Search name", index=True)
    # help_text previously read "KB description", copied from another model.
    description = TextField(null=True, help_text="Search description")
    created_by = CharField(max_length=32, null=False, index=True)
    # NOTE(review): the default is a single dict object; depending on how the
    # field hands out defaults it may be shared between instances — confirm
    # before mutating a record's config in place.
    search_config = JSONField(
        null=False,
        default={
            "kb_ids": [],
            "doc_ids": [],
            "similarity_threshold": 0.0,
            "vector_similarity_weight": 0.3,
            "use_kg": False,
            # rerank settings
            "rerank_id": "",
            "top_k": 1024,
            # chat settings
            "summary": False,
            "chat_id": "",
            "llm_setting": {
                "temperature": 0.1,
                "top_p": 0.3,
                "frequency_penalty": 0.7,
                "presence_penalty": 0.4,
            },
            # NOTE(review): this key looks like two names fused together
            # (possibly "cross_languages"); left unchanged because stored
            # rows and clients may already depend on it — confirm.
            "chat_settingcross_languages": [],
            "highlight": False,
            "keyword": False,
            "web_search": False,
            "related_search": False,
            "query_mindmap": False,
        },
    )
    status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)

    def __str__(self):
        """Return the search app's name."""
        return self.name

    class Meta:
        db_table = "search"


def migrate_db(): def migrate_db():
migrator = DatabaseMigrator[settings.DATABASE_TYPE.upper()].value(DB) migrator = DatabaseMigrator[settings.DATABASE_TYPE.upper()].value(DB)
try: try:

+ 3
- 2
api/db/services/dialog_service.py 查看文件

re.compile(r"ref\s*(\d+)", flags=re.IGNORECASE), # ref12、REF 12 re.compile(r"ref\s*(\d+)", flags=re.IGNORECASE), # ref12、REF 12
] ]



def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set): def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set):
max_index = len(kbinfos["chunks"]) max_index = len(kbinfos["chunks"])


return binascii.hexlify(bin).decode("utf-8") return binascii.hexlify(bin).decode("utf-8")




def ask(question, kb_ids, tenant_id):
def ask(question, kb_ids, tenant_id, chat_llm_name=None):
kbs = KnowledgebaseService.get_by_ids(kb_ids) kbs = KnowledgebaseService.get_by_ids(kb_ids)
embedding_list = list(set([kb.embd_id for kb in kbs])) embedding_list = list(set([kb.embd_id for kb in kbs]))


retriever = settings.retrievaler if not is_knowledge_graph else settings.kg_retrievaler retriever = settings.retrievaler if not is_knowledge_graph else settings.kg_retrievaler


embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embedding_list[0]) embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embedding_list[0])
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, chat_llm_name)
max_tokens = chat_mdl.max_length max_tokens = chat_mdl.max_length
tenant_ids = list(set([kb.tenant_id for kb in kbs])) tenant_ids = list(set([kb.tenant_id for kb in kbs]))
kbinfos = retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, 1, 12, 0.1, 0.3, aggs=False, rank_feature=label_question(question, kbs)) kbinfos = retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, 1, 12, 0.1, 0.3, aggs=False, rank_feature=label_question(question, kbs))

+ 110
- 0
api/db/services/search_service.py 查看文件

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from datetime import datetime

from peewee import fn

from api.db import StatusEnum
from api.db.db_models import DB, Search, User
from api.db.services.common_service import CommonService
from api.utils import current_timestamp, datetime_format


class SearchService(CommonService):
    """Data-access helpers for the ``Search`` model."""

    model = Search

    @classmethod
    def save(cls, **kwargs):
        """Create a ``Search`` row from ``kwargs`` and return the new object.

        The four create/update time and date fields are always overwritten
        here. The clock is sampled once so that a freshly created row carries
        identical create and update stamps.
        """
        timestamp = current_timestamp()
        date_text = datetime_format(datetime.now())
        kwargs["create_time"] = timestamp
        kwargs["create_date"] = date_text
        kwargs["update_time"] = timestamp
        kwargs["update_date"] = date_text
        obj = cls.model.create(**kwargs)
        return obj

    @classmethod
    @DB.connection_context()
    def accessible4deletion(cls, search_id, user_id) -> bool:
        """Return True when a valid search ``search_id`` was created by ``user_id``."""
        search = (
            cls.model.select(cls.model.id)
            .where(
                cls.model.id == search_id,
                cls.model.created_by == user_id,
                cls.model.status == StatusEnum.VALID.value,
            )
            .first()
        )
        return search is not None

    @classmethod
    @DB.connection_context()
    def get_detail(cls, search_id):
        """Return the detail dict of a valid search app, or ``None``.

        The row is joined with the valid ``User`` whose id equals the
        search's ``tenant_id`` to pick up ``nickname`` and the user's avatar
        (aliased ``tenant_avatar``).

        Returns:
            ``to_dict()`` of the matching row, or ``None`` when no row
            matches (callers test the result for truthiness).
        """
        fields = [
            cls.model.id,
            cls.model.avatar,
            cls.model.tenant_id,
            cls.model.name,
            cls.model.description,
            cls.model.created_by,
            cls.model.search_config,
            cls.model.update_time,
            User.nickname,
            User.avatar.alias("tenant_avatar"),
        ]
        row = (
            cls.model.select(*fields)
            .join(User, on=((User.id == cls.model.tenant_id) & (User.status == StatusEnum.VALID.value)))
            .where((cls.model.id == search_id) & (cls.model.status == StatusEnum.VALID.value))
            .first()
        )
        # first() yields None on an empty result; calling to_dict() on it
        # would raise AttributeError instead of signalling "not found".
        if row is None:
            return None
        return row.to_dict()

    @classmethod
    @DB.connection_context()
    def get_by_tenant_ids(cls, joined_tenant_ids, user_id, page_number, items_per_page, orderby, desc, keywords):
        """List valid search apps of the given tenants or of ``user_id``.

        Args:
            joined_tenant_ids: tenant ids whose search apps are included.
            user_id: search apps whose ``tenant_id`` equals this id are
                included as well.
            page_number: page to return; pagination is applied only when both
                this and ``items_per_page`` are truthy.
            items_per_page: page size.
            orderby: field name resolved through ``cls.model.getter_by``.
            desc: sort descending when truthy, ascending otherwise.
            keywords: when non-empty, keep rows whose lower-cased name
                contains the lower-cased keywords.

        Returns:
            ``(rows, count)`` where ``rows`` is a list of dicts for the
            requested page and ``count`` is the number of matching rows
            before pagination.
        """
        fields = [
            cls.model.id,
            cls.model.avatar,
            cls.model.tenant_id,
            cls.model.name,
            cls.model.description,
            cls.model.created_by,
            cls.model.status,
            cls.model.update_time,
            cls.model.create_time,
            User.nickname,
            User.avatar.alias("tenant_avatar"),
        ]
        query = (
            cls.model.select(*fields)
            .join(User, on=(cls.model.tenant_id == User.id))
            .where(((cls.model.tenant_id.in_(joined_tenant_ids)) | (cls.model.tenant_id == user_id)) & (cls.model.status == StatusEnum.VALID.value))
        )

        if keywords:
            query = query.where(fn.LOWER(cls.model.name).contains(keywords.lower()))
        if desc:
            query = query.order_by(cls.model.getter_by(orderby).desc())
        else:
            query = query.order_by(cls.model.getter_by(orderby).asc())

        # Count before paginating so the total reflects every matching row.
        count = query.count()

        if page_number and items_per_page:
            query = query.paginate(page_number, items_per_page)

        return list(query.dicts()), count

Loading…
取消
儲存