|
|
|
@@ -31,9 +31,7 @@ from api.utils.api_utils import get_result |
|
|
|
@token_required |
|
|
|
def create(tenant_id): |
|
|
|
req = request.json |
|
|
|
ids = req.get("dataset_ids") |
|
|
|
if not ids: |
|
|
|
return get_error_data_result(message="`dataset_ids` is required") |
|
|
|
ids = [i for i in req.get("dataset_ids", []) if i] |
|
|
|
for kb_id in ids: |
|
|
|
kbs = KnowledgebaseService.accessible(kb_id=kb_id, user_id=tenant_id) |
|
|
|
if not kbs: |
|
|
|
@@ -42,10 +40,10 @@ def create(tenant_id): |
|
|
|
kb = kbs[0] |
|
|
|
if kb.chunk_num == 0: |
|
|
|
return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file") |
|
|
|
kbs = KnowledgebaseService.get_by_ids(ids) |
|
|
|
kbs = KnowledgebaseService.get_by_ids(ids) if ids else [] |
|
|
|
embd_ids = [TenantLLMService.split_model_name_and_factory(kb.embd_id)[0] for kb in kbs] # remove vendor suffix for comparison |
|
|
|
embd_count = list(set(embd_ids)) |
|
|
|
if len(embd_count) != 1: |
|
|
|
if len(embd_count) > 1: |
|
|
|
return get_result(message='Datasets use different embedding models."', |
|
|
|
code=settings.RetCode.AUTHENTICATION_ERROR) |
|
|
|
req["kb_ids"] = ids |