Browse Source

Feat/api validate model provider (#21582)

Co-authored-by: crazywoola <427733928@qq.com>
tags/1.5.1
Khoa 4 months ago
parent
commit
a06af88b26
No account linked to committer's email address

+ 27
- 1
api/controllers/service_api/dataset/dataset.py View File

parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json") parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")


args = parser.parse_args() args = parser.parse_args()

if args.get("embedding_model_provider"):
DatasetService.check_embedding_model_setting(
tenant_id, args.get("embedding_model_provider"), args.get("embedding_model")
)
if (
args.get("retrieval_model")
and args.get("retrieval_model").get("reranking_model")
and args.get("retrieval_model").get("reranking_model").get("reranking_provider_name")
):
DatasetService.check_reranking_model_setting(
tenant_id,
args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"),
args.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
)

try: try:
dataset = DatasetService.create_empty_dataset( dataset = DatasetService.create_empty_dataset(
tenant_id=tenant_id, tenant_id=tenant_id,
data = request.get_json() data = request.get_json()


# check embedding model setting # check embedding model setting
if data.get("indexing_technique") == "high_quality":
if data.get("indexing_technique") == "high_quality" or data.get("embedding_model_provider"):
DatasetService.check_embedding_model_setting( DatasetService.check_embedding_model_setting(
dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model") dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model")
) )
if (
data.get("retrieval_model")
and data.get("retrieval_model").get("reranking_model")
and data.get("retrieval_model").get("reranking_model").get("reranking_provider_name")
):
DatasetService.check_reranking_model_setting(
dataset.tenant_id,
data.get("retrieval_model").get("reranking_model").get("reranking_provider_name"),
data.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
)


# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
DatasetPermissionService.check_permission( DatasetPermissionService.check_permission(

+ 43
- 1
api/controllers/service_api/dataset/document.py View File

from fields.document_fields import document_fields, document_status_fields from fields.document_fields import document_fields, document_status_fields
from libs.login import current_user from libs.login import current_user
from models.dataset import Dataset, Document, DocumentSegment from models.dataset import Dataset, Document, DocumentSegment
from services.dataset_service import DocumentService
from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
from services.file_service import FileService from services.file_service import FileService


parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json") parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")


args = parser.parse_args() args = parser.parse_args()

dataset_id = str(dataset_id) dataset_id = str(dataset_id)
tenant_id = str(tenant_id) tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if text is None or name is None: if text is None or name is None:
raise ValueError("Both 'text' and 'name' must be non-null values.") raise ValueError("Both 'text' and 'name' must be non-null values.")


if args.get("embedding_model_provider"):
DatasetService.check_embedding_model_setting(
tenant_id, args.get("embedding_model_provider"), args.get("embedding_model")
)
if (
args.get("retrieval_model")
and args.get("retrieval_model").get("reranking_model")
and args.get("retrieval_model").get("reranking_model").get("reranking_provider_name")
):
DatasetService.check_reranking_model_setting(
tenant_id,
args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"),
args.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
)

upload_file = FileService.upload_text(text=str(text), text_name=str(name)) upload_file = FileService.upload_text(text=str(text), text_name=str(name))
data_source = { data_source = {
"type": "upload_file", "type": "upload_file",
if not dataset: if not dataset:
raise ValueError("Dataset does not exist.") raise ValueError("Dataset does not exist.")


if (
args.get("retrieval_model")
and args.get("retrieval_model").get("reranking_model")
and args.get("retrieval_model").get("reranking_model").get("reranking_provider_name")
):
DatasetService.check_reranking_model_setting(
tenant_id,
args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"),
args.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
)

# indexing_technique is already set in dataset since this is an update # indexing_technique is already set in dataset since this is an update
args["indexing_technique"] = dataset.indexing_technique args["indexing_technique"] = dataset.indexing_technique


raise ValueError("indexing_technique is required.") raise ValueError("indexing_technique is required.")
args["indexing_technique"] = indexing_technique args["indexing_technique"] = indexing_technique


if "embedding_model_provider" in args:
DatasetService.check_embedding_model_setting(
tenant_id, args["embedding_model_provider"], args["embedding_model"]
)
if (
"retrieval_model" in args
and args["retrieval_model"].get("reranking_model")
and args["retrieval_model"].get("reranking_model").get("reranking_provider_name")
):
DatasetService.check_reranking_model_setting(
tenant_id,
args["retrieval_model"].get("reranking_model").get("reranking_provider_name"),
args["retrieval_model"].get("reranking_model").get("reranking_model_name"),
)

# save file info # save file info
file = request.files["file"] file = request.files["file"]
# check file # check file

+ 17
- 0
api/services/dataset_service.py View File

except ProviderTokenNotInitError as ex: except ProviderTokenNotInitError as ex:
raise ValueError(ex.description) raise ValueError(ex.description)


@staticmethod
def check_reranking_model_setting(tenant_id: str, reranking_model_provider: str, reranking_model: str):
try:
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=tenant_id,
provider=reranking_model_provider,
model_type=ModelType.RERANK,
model=reranking_model,
)
except LLMBadRequestError:
raise ValueError(
"No Rerank Model available. Please configure a valid provider in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ValueError(ex.description)

@staticmethod @staticmethod
def update_dataset(dataset_id, data, user): def update_dataset(dataset_id, data, user):
""" """

Loading…
Cancel
Save