Просмотр исходного кода

Fix typing errors in dataset API (#26424)

Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
tags/1.9.1
Asuka Minato 1 месяц назад
Родитель
Сommit
d77c2e4d17
Аккаунт пользователя с таким Email не найден

+ 38
- 28
api/controllers/service_api/dataset/dataset.py Просмотреть файл

from typing import Literal
from typing import Any, Literal, cast


from flask import request from flask import request
from flask_restx import marshal, reqparse from flask_restx import marshal, reqparse
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound


import services.dataset_service
import services
from controllers.service_api import service_api_ns from controllers.service_api import service_api_ns
from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError
from controllers.service_api.wraps import ( from controllers.service_api.wraps import (
"""Resource for creating datasets.""" """Resource for creating datasets."""
args = dataset_create_parser.parse_args() args = dataset_create_parser.parse_args()


if args.get("embedding_model_provider"):
DatasetService.check_embedding_model_setting(
tenant_id, args.get("embedding_model_provider"), args.get("embedding_model")
)
embedding_model_provider = args.get("embedding_model_provider")
embedding_model = args.get("embedding_model")
if embedding_model_provider and embedding_model:
DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model)

retrieval_model = args.get("retrieval_model")
if ( if (
args.get("retrieval_model")
and args.get("retrieval_model").get("reranking_model")
and args.get("retrieval_model").get("reranking_model").get("reranking_provider_name")
retrieval_model
and retrieval_model.get("reranking_model")
and retrieval_model.get("reranking_model").get("reranking_provider_name")
): ):
DatasetService.check_reranking_model_setting( DatasetService.check_reranking_model_setting(
tenant_id, tenant_id,
args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"),
args.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
retrieval_model.get("reranking_model").get("reranking_provider_name"),
retrieval_model.get("reranking_model").get("reranking_model_name"),
) )


try: try:
DatasetService.check_dataset_permission(dataset, current_user) DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e: except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e)) raise Forbidden(str(e))
data = marshal(dataset, dataset_detail_fields)
data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
# check embedding setting # check embedding setting
provider_manager = ProviderManager() provider_manager = ProviderManager()
assert isinstance(current_user, Account) assert isinstance(current_user, Account)
for embedding_model in embedding_models: for embedding_model in embedding_models:
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}") model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")


if data["indexing_technique"] == "high_quality":
item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}"
if data.get("indexing_technique") == "high_quality":
item_model = f"{data.get('embedding_model')}:{data.get('embedding_model_provider')}"
if item_model in model_names: if item_model in model_names:
data["embedding_available"] = True data["embedding_available"] = True
else: else:
data["embedding_available"] = True data["embedding_available"] = True


# force update search method to keyword_search if indexing_technique is economic # force update search method to keyword_search if indexing_technique is economic
data["retrieval_model_dict"]["search_method"] = "keyword_search"
retrieval_model_dict = data.get("retrieval_model_dict")
if retrieval_model_dict:
retrieval_model_dict["search_method"] = "keyword_search"


if data.get("permission") == "partial_members": if data.get("permission") == "partial_members":
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str) part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
data = request.get_json() data = request.get_json()


# check embedding model setting # check embedding model setting
if data.get("indexing_technique") == "high_quality" or data.get("embedding_model_provider"):
DatasetService.check_embedding_model_setting(
dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model")
)
embedding_model_provider = data.get("embedding_model_provider")
embedding_model = data.get("embedding_model")
if data.get("indexing_technique") == "high_quality" or embedding_model_provider:
if embedding_model_provider and embedding_model:
DatasetService.check_embedding_model_setting(
dataset.tenant_id, embedding_model_provider, embedding_model
)

retrieval_model = data.get("retrieval_model")
if ( if (
data.get("retrieval_model")
and data.get("retrieval_model").get("reranking_model")
and data.get("retrieval_model").get("reranking_model").get("reranking_provider_name")
retrieval_model
and retrieval_model.get("reranking_model")
and retrieval_model.get("reranking_model").get("reranking_provider_name")
): ):
DatasetService.check_reranking_model_setting( DatasetService.check_reranking_model_setting(
dataset.tenant_id, dataset.tenant_id,
data.get("retrieval_model").get("reranking_model").get("reranking_provider_name"),
data.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
retrieval_model.get("reranking_model").get("reranking_provider_name"),
retrieval_model.get("reranking_model").get("reranking_model_name"),
) )


# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
if dataset is None: if dataset is None:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")


result_data = marshal(dataset, dataset_detail_fields)
result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
assert isinstance(current_user, Account) assert isinstance(current_user, Account)
tenant_id = current_user.current_tenant_id tenant_id = current_user.current_tenant_id




args = tag_update_parser.parse_args() args = tag_update_parser.parse_args()
args["type"] = "knowledge" args["type"] = "knowledge"
tag = TagService.update_tags(args, args.get("tag_id"))
tag_id = args["tag_id"]
tag = TagService.update_tags(args, tag_id)


binding_count = TagService.get_tag_binding_count(args.get("tag_id"))
binding_count = TagService.get_tag_binding_count(tag_id)


response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count} response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}


if not current_user.has_edit_permission: if not current_user.has_edit_permission:
raise Forbidden() raise Forbidden()
args = tag_delete_parser.parse_args() args = tag_delete_parser.parse_args()
TagService.delete_tag(args.get("tag_id"))
TagService.delete_tag(args["tag_id"])


return 204 return 204



+ 17
- 14
api/controllers/service_api/dataset/document.py Просмотреть файл

if text is None or name is None: if text is None or name is None:
raise ValueError("Both 'text' and 'name' must be non-null values.") raise ValueError("Both 'text' and 'name' must be non-null values.")


if args.get("embedding_model_provider"):
DatasetService.check_embedding_model_setting(
tenant_id, args.get("embedding_model_provider"), args.get("embedding_model")
)
embedding_model_provider = args.get("embedding_model_provider")
embedding_model = args.get("embedding_model")
if embedding_model_provider and embedding_model:
DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model)

retrieval_model = args.get("retrieval_model")
if ( if (
args.get("retrieval_model")
and args.get("retrieval_model").get("reranking_model")
and args.get("retrieval_model").get("reranking_model").get("reranking_provider_name")
retrieval_model
and retrieval_model.get("reranking_model")
and retrieval_model.get("reranking_model").get("reranking_provider_name")
): ):
DatasetService.check_reranking_model_setting( DatasetService.check_reranking_model_setting(
tenant_id, tenant_id,
args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"),
args.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
retrieval_model.get("reranking_model").get("reranking_provider_name"),
retrieval_model.get("reranking_model").get("reranking_model_name"),
) )


if not current_user: if not current_user:
if not dataset: if not dataset:
raise ValueError("Dataset does not exist.") raise ValueError("Dataset does not exist.")


retrieval_model = args.get("retrieval_model")
if ( if (
args.get("retrieval_model")
and args.get("retrieval_model").get("reranking_model")
and args.get("retrieval_model").get("reranking_model").get("reranking_provider_name")
retrieval_model
and retrieval_model.get("reranking_model")
and retrieval_model.get("reranking_model").get("reranking_provider_name")
): ):
DatasetService.check_reranking_model_setting( DatasetService.check_reranking_model_setting(
tenant_id, tenant_id,
args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"),
args.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
retrieval_model.get("reranking_model").get("reranking_provider_name"),
retrieval_model.get("reranking_model").get("reranking_model_name"),
) )


# indexing_technique is already set in dataset since this is an update # indexing_technique is already set in dataset since this is an update

+ 1
- 1
api/controllers/service_api/dataset/metadata.py Просмотреть файл

raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
DatasetService.check_dataset_permission(dataset, current_user) DatasetService.check_dataset_permission(dataset, current_user)


metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args.get("name"))
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args["name"])
return marshal(metadata, dataset_metadata_fields), 200 return marshal(metadata, dataset_metadata_fields), 200


@service_api_ns.doc("delete_dataset_metadata") @service_api_ns.doc("delete_dataset_metadata")

+ 0
- 1
api/pyrightconfig.json Просмотреть файл

"extensions", "extensions",
"libs", "libs",
"controllers/console/datasets", "controllers/console/datasets",
"controllers/service_api/dataset",
"core/ops", "core/ops",
"core/tools", "core/tools",
"core/model_runtime", "core/model_runtime",

Загрузка…
Отмена
Сохранить