Sfoglia il codice sorgente

fix: improve embedding model validation logic for dataset operations (#3235)

What problem does this PR solve?
When creating or updating datasets with custom embedding models (e.g.,
Ollama), the validation logic was too restrictive and prevented valid
models from being used. The previous implementation would reject valid
custom models if they weren't in the predefined list, even when they
existed in TenantLLMService.

Changes:
- Simplify and improve the embedding model validation flow in
create/update endpoints
- Check TenantLLMService for custom models before rejecting
- Make validation logic more consistent between create and update
operations

### What problem does this PR solve?

This fix allows users to successfully create and update datasets with
custom embedding models while maintaining proper validation checks.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---------

Co-authored-by: Jin Hai <haijin.chn@gmail.com>
Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
Co-authored-by: liuhua <10215101452@stu.ecnu.edu.cn>
tags/v0.14.0
Mohammed Tawileh 11 mesi fa
parent
commit
5038552ed9
Nessun account collegato all'indirizzo email del committer
1 ha cambiato i file con 11 aggiunte e 22 eliminazioni
  1. 11
    22
      api/apps/sdk/dataset.py

+ 11
- 22
api/apps/sdk/dataset.py Vedi File

@@ -159,21 +159,15 @@ def create(tenant_id):
embd_model = LLMService.query(
llm_name=req["embedding_model"], model_type="embedding"
)
if embd_model:
if req["embedding_model"] not in valid_embedding_models and not TenantLLMService.query(tenant_id=tenant_id,model_type="embedding",llm_name=req.get("embedding_model"),):
return get_error_data_result(f"`embedding_model` {req.get('embedding_model')} doesn't exist")
if not embd_model:
embd_model=TenantLLMService.query(tenant_id=tenant_id,model_type="embedding", llm_name=req.get("embedding_model"))
if not embd_model:
return get_error_data_result(
f"`embedding_model` {req.get('embedding_model')} doesn't exist"
)
if embd_model:
if req[
"embedding_model"
] not in valid_embedding_models and not TenantLLMService.query(
tenant_id=tenant_id,
model_type="embedding",
llm_name=req.get("embedding_model"),
):
return get_error_data_result(
f"`embedding_model` {req.get('embedding_model')} doesn't exist"
)
key_mapping = {
"chunk_num": "chunk_count",
"doc_num": "document_count",
@@ -403,21 +397,16 @@ def update(tenant_id, dataset_id):
embd_model = LLMService.query(
llm_name=req["embedding_model"], model_type="embedding"
)
if embd_model:
if req["embedding_model"] not in valid_embedding_models and not TenantLLMService.query(tenant_id=tenant_id,model_type="embedding",llm_name=req.get("embedding_model"),):
return get_error_data_result(f"`embedding_model` {req.get('embedding_model')} doesn't exist")
if not embd_model:
embd_model=TenantLLMService.query(tenant_id=tenant_id,model_type="embedding", llm_name=req.get("embedding_model"))

if not embd_model:
return get_error_data_result(
f"`embedding_model` {req.get('embedding_model')} doesn't exist"
)
if embd_model:
if req[
"embedding_model"
] not in valid_embedding_models and not TenantLLMService.query(
tenant_id=tenant_id,
model_type="embedding",
llm_name=req.get("embedding_model"),
):
return get_error_data_result(
f"`embedding_model` {req.get('embedding_model')} doesn't exist"
)
req["embd_id"] = req.pop("embedding_model")
if "name" in req:
req["name"] = req["name"].strip()

Loading…
Annulla
Salva