소스 검색

Feat/api validate model provider (#21582)

Co-authored-by: crazywoola <427733928@qq.com>
tags/1.5.1
Khoa 4 달 전
부모
커밋
a06af88b26
No account linked to committer's email address
3개의 변경된 파일87개의 추가작업 그리고 2개의 파일을 삭제
  1. 27
    1
      api/controllers/service_api/dataset/dataset.py
  2. 43
    1
      api/controllers/service_api/dataset/document.py
  3. 17
    0
      api/services/dataset_service.py

+ 27
- 1
api/controllers/service_api/dataset/dataset.py 파일 보기

@@ -133,6 +133,22 @@ class DatasetListApi(DatasetApiResource):
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")

args = parser.parse_args()

if args.get("embedding_model_provider"):
DatasetService.check_embedding_model_setting(
tenant_id, args.get("embedding_model_provider"), args.get("embedding_model")
)
if (
args.get("retrieval_model")
and args.get("retrieval_model").get("reranking_model")
and args.get("retrieval_model").get("reranking_model").get("reranking_provider_name")
):
DatasetService.check_reranking_model_setting(
tenant_id,
args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"),
args.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
)

try:
dataset = DatasetService.create_empty_dataset(
tenant_id=tenant_id,
@@ -265,10 +281,20 @@ class DatasetApi(DatasetApiResource):
data = request.get_json()

# check embedding model setting
if data.get("indexing_technique") == "high_quality":
if data.get("indexing_technique") == "high_quality" or data.get("embedding_model_provider"):
DatasetService.check_embedding_model_setting(
dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model")
)
if (
data.get("retrieval_model")
and data.get("retrieval_model").get("reranking_model")
and data.get("retrieval_model").get("reranking_model").get("reranking_provider_name")
):
DatasetService.check_reranking_model_setting(
dataset.tenant_id,
data.get("retrieval_model").get("reranking_model").get("reranking_provider_name"),
data.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
)

# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
DatasetPermissionService.check_permission(

+ 43
- 1
api/controllers/service_api/dataset/document.py 파일 보기

@@ -29,7 +29,7 @@ from extensions.ext_database import db
from fields.document_fields import document_fields, document_status_fields
from libs.login import current_user
from models.dataset import Dataset, Document, DocumentSegment
from services.dataset_service import DocumentService
from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
from services.file_service import FileService

@@ -59,6 +59,7 @@ class DocumentAddByTextApi(DatasetApiResource):
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")

args = parser.parse_args()

dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
@@ -74,6 +75,21 @@ class DocumentAddByTextApi(DatasetApiResource):
if text is None or name is None:
raise ValueError("Both 'text' and 'name' must be non-null values.")

if args.get("embedding_model_provider"):
DatasetService.check_embedding_model_setting(
tenant_id, args.get("embedding_model_provider"), args.get("embedding_model")
)
if (
args.get("retrieval_model")
and args.get("retrieval_model").get("reranking_model")
and args.get("retrieval_model").get("reranking_model").get("reranking_provider_name")
):
DatasetService.check_reranking_model_setting(
tenant_id,
args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"),
args.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
)

upload_file = FileService.upload_text(text=str(text), text_name=str(name))
data_source = {
"type": "upload_file",
@@ -124,6 +140,17 @@ class DocumentUpdateByTextApi(DatasetApiResource):
if not dataset:
raise ValueError("Dataset does not exist.")

if (
args.get("retrieval_model")
and args.get("retrieval_model").get("reranking_model")
and args.get("retrieval_model").get("reranking_model").get("reranking_provider_name")
):
DatasetService.check_reranking_model_setting(
tenant_id,
args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"),
args.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
)

# indexing_technique is already set in dataset since this is an update
args["indexing_technique"] = dataset.indexing_technique

@@ -188,6 +215,21 @@ class DocumentAddByFileApi(DatasetApiResource):
raise ValueError("indexing_technique is required.")
args["indexing_technique"] = indexing_technique

if "embedding_model_provider" in args:
DatasetService.check_embedding_model_setting(
tenant_id, args["embedding_model_provider"], args["embedding_model"]
)
if (
"retrieval_model" in args
and args["retrieval_model"].get("reranking_model")
and args["retrieval_model"].get("reranking_model").get("reranking_provider_name")
):
DatasetService.check_reranking_model_setting(
tenant_id,
args["retrieval_model"].get("reranking_model").get("reranking_provider_name"),
args["retrieval_model"].get("reranking_model").get("reranking_model_name"),
)

# save file info
file = request.files["file"]
# check file

+ 17
- 0
api/services/dataset_service.py 파일 보기

@@ -278,6 +278,23 @@ class DatasetService:
except ProviderTokenNotInitError as ex:
raise ValueError(ex.description)

@staticmethod
def check_reranking_model_setting(tenant_id: str, reranking_model_provider: str, reranking_model: str):
try:
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=tenant_id,
provider=reranking_model_provider,
model_type=ModelType.RERANK,
model=reranking_model,
)
except LLMBadRequestError:
raise ValueError(
"No Rerank Model available. Please configure a valid provider in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ValueError(ex.description)

@staticmethod
def update_dataset(dataset_id, data, user):
"""

Loading…
취소
저장