|
|
|
@@ -1,10 +1,10 @@ |
|
|
|
from typing import Literal |
|
|
|
from typing import Any, Literal, cast |
|
|
|
|
|
|
|
from flask import request |
|
|
|
from flask_restx import marshal, reqparse |
|
|
|
from werkzeug.exceptions import Forbidden, NotFound |
|
|
|
|
|
|
|
import services.dataset_service |
|
|
|
import services |
|
|
|
from controllers.service_api import service_api_ns |
|
|
|
from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError |
|
|
|
from controllers.service_api.wraps import ( |
|
|
|
@@ -254,19 +254,21 @@ class DatasetListApi(DatasetApiResource): |
|
|
|
"""Resource for creating datasets.""" |
|
|
|
args = dataset_create_parser.parse_args() |
|
|
|
|
|
|
|
if args.get("embedding_model_provider"): |
|
|
|
DatasetService.check_embedding_model_setting( |
|
|
|
tenant_id, args.get("embedding_model_provider"), args.get("embedding_model") |
|
|
|
) |
|
|
|
embedding_model_provider = args.get("embedding_model_provider") |
|
|
|
embedding_model = args.get("embedding_model") |
|
|
|
if embedding_model_provider and embedding_model: |
|
|
|
DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model) |
|
|
|
|
|
|
|
retrieval_model = args.get("retrieval_model") |
|
|
|
if ( |
|
|
|
args.get("retrieval_model") |
|
|
|
and args.get("retrieval_model").get("reranking_model") |
|
|
|
and args.get("retrieval_model").get("reranking_model").get("reranking_provider_name") |
|
|
|
retrieval_model |
|
|
|
and retrieval_model.get("reranking_model") |
|
|
|
and retrieval_model.get("reranking_model").get("reranking_provider_name") |
|
|
|
): |
|
|
|
DatasetService.check_reranking_model_setting( |
|
|
|
tenant_id, |
|
|
|
args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"), |
|
|
|
args.get("retrieval_model").get("reranking_model").get("reranking_model_name"), |
|
|
|
retrieval_model.get("reranking_model").get("reranking_provider_name"), |
|
|
|
retrieval_model.get("reranking_model").get("reranking_model_name"), |
|
|
|
) |
|
|
|
|
|
|
|
try: |
|
|
|
@@ -317,7 +319,7 @@ class DatasetApi(DatasetApiResource): |
|
|
|
DatasetService.check_dataset_permission(dataset, current_user) |
|
|
|
except services.errors.account.NoPermissionError as e: |
|
|
|
raise Forbidden(str(e)) |
|
|
|
data = marshal(dataset, dataset_detail_fields) |
|
|
|
data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields)) |
|
|
|
# check embedding setting |
|
|
|
provider_manager = ProviderManager() |
|
|
|
assert isinstance(current_user, Account) |
|
|
|
@@ -331,8 +333,8 @@ class DatasetApi(DatasetApiResource): |
|
|
|
for embedding_model in embedding_models: |
|
|
|
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}") |
|
|
|
|
|
|
|
if data["indexing_technique"] == "high_quality": |
|
|
|
item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}" |
|
|
|
if data.get("indexing_technique") == "high_quality": |
|
|
|
item_model = f"{data.get('embedding_model')}:{data.get('embedding_model_provider')}" |
|
|
|
if item_model in model_names: |
|
|
|
data["embedding_available"] = True |
|
|
|
else: |
|
|
|
@@ -341,7 +343,9 @@ class DatasetApi(DatasetApiResource): |
|
|
|
data["embedding_available"] = True |
|
|
|
|
|
|
|
# force update search method to keyword_search if indexing_technique is economic |
|
|
|
data["retrieval_model_dict"]["search_method"] = "keyword_search" |
|
|
|
retrieval_model_dict = data.get("retrieval_model_dict") |
|
|
|
if retrieval_model_dict: |
|
|
|
retrieval_model_dict["search_method"] = "keyword_search" |
|
|
|
|
|
|
|
if data.get("permission") == "partial_members": |
|
|
|
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str) |
|
|
|
@@ -372,19 +376,24 @@ class DatasetApi(DatasetApiResource): |
|
|
|
data = request.get_json() |
|
|
|
|
|
|
|
# check embedding model setting |
|
|
|
if data.get("indexing_technique") == "high_quality" or data.get("embedding_model_provider"): |
|
|
|
DatasetService.check_embedding_model_setting( |
|
|
|
dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model") |
|
|
|
) |
|
|
|
embedding_model_provider = data.get("embedding_model_provider") |
|
|
|
embedding_model = data.get("embedding_model") |
|
|
|
if data.get("indexing_technique") == "high_quality" or embedding_model_provider: |
|
|
|
if embedding_model_provider and embedding_model: |
|
|
|
DatasetService.check_embedding_model_setting( |
|
|
|
dataset.tenant_id, embedding_model_provider, embedding_model |
|
|
|
) |
|
|
|
|
|
|
|
retrieval_model = data.get("retrieval_model") |
|
|
|
if ( |
|
|
|
data.get("retrieval_model") |
|
|
|
and data.get("retrieval_model").get("reranking_model") |
|
|
|
and data.get("retrieval_model").get("reranking_model").get("reranking_provider_name") |
|
|
|
retrieval_model |
|
|
|
and retrieval_model.get("reranking_model") |
|
|
|
and retrieval_model.get("reranking_model").get("reranking_provider_name") |
|
|
|
): |
|
|
|
DatasetService.check_reranking_model_setting( |
|
|
|
dataset.tenant_id, |
|
|
|
data.get("retrieval_model").get("reranking_model").get("reranking_provider_name"), |
|
|
|
data.get("retrieval_model").get("reranking_model").get("reranking_model_name"), |
|
|
|
retrieval_model.get("reranking_model").get("reranking_provider_name"), |
|
|
|
retrieval_model.get("reranking_model").get("reranking_model_name"), |
|
|
|
) |
|
|
|
|
|
|
|
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator |
|
|
|
@@ -397,7 +406,7 @@ class DatasetApi(DatasetApiResource): |
|
|
|
if dataset is None: |
|
|
|
raise NotFound("Dataset not found.") |
|
|
|
|
|
|
|
result_data = marshal(dataset, dataset_detail_fields) |
|
|
|
result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields)) |
|
|
|
assert isinstance(current_user, Account) |
|
|
|
tenant_id = current_user.current_tenant_id |
|
|
|
|
|
|
|
@@ -591,9 +600,10 @@ class DatasetTagsApi(DatasetApiResource): |
|
|
|
|
|
|
|
args = tag_update_parser.parse_args() |
|
|
|
args["type"] = "knowledge" |
|
|
|
tag = TagService.update_tags(args, args.get("tag_id")) |
|
|
|
tag_id = args["tag_id"] |
|
|
|
tag = TagService.update_tags(args, tag_id) |
|
|
|
|
|
|
|
binding_count = TagService.get_tag_binding_count(args.get("tag_id")) |
|
|
|
binding_count = TagService.get_tag_binding_count(tag_id) |
|
|
|
|
|
|
|
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count} |
|
|
|
|
|
|
|
@@ -616,7 +626,7 @@ class DatasetTagsApi(DatasetApiResource): |
|
|
|
if not current_user.has_edit_permission: |
|
|
|
raise Forbidden() |
|
|
|
args = tag_delete_parser.parse_args() |
|
|
|
TagService.delete_tag(args.get("tag_id")) |
|
|
|
TagService.delete_tag(args["tag_id"]) |
|
|
|
|
|
|
|
return 204 |
|
|
|
|