Просмотр исходного кода

Fix: Normalize embedding model ID comparison across datasets (#5169)

Modify embedding model ID comparison to remove vendor suffixes, ensuring
consistent model identification when working with multiple knowledge
bases. This change affects dialog creation, chat operations, and
document retrieval test functions.

### What problem does this PR solve?

resolve this bug: https://github.com/infiniflow/ragflow/issues/5166

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---------

Co-authored-by: wenju.li <wenju.li@deepctr.cn>
tags/v0.17.0
liwenju0 8 месяцев назад
Родитель
Сommit
f298e55ded
Аккаунт пользователя с таким Email не найден
3 измененных файлов: 8 добавлений и 4 удалений
  1. 3
    1
      api/apps/dialog_app.py
  2. 4
    2
      api/apps/sdk/chat.py
  3. 1
    1
      api/apps/sdk/doc.py

+ 3
- 1
api/apps/dialog_app.py Просмотреть файл

@@ -18,6 +18,7 @@ from flask import request
from flask_login import login_required, current_user
from api.db.services.dialog_service import DialogService
from api.db import StatusEnum
from api.db.services.llm_service import TenantLLMService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.user_service import TenantService, UserTenantService
from api import settings
@@ -75,7 +76,8 @@ def set_dialog():
if not e:
return get_data_error_result(message="Tenant not found!")
kbs = KnowledgebaseService.get_by_ids(req.get("kb_ids"))
embd_count = len(set([kb.embd_id for kb in kbs]))
embd_ids = [TenantLLMService.split_model_name_and_factory(kb.embd_id)[0] for kb in kbs] # remove vendor suffix for comparison
embd_count = len(set(embd_ids))
if embd_count != 1:
return get_data_error_result(message=f'Datasets use different embedding models: {[kb.embd_id for kb in kbs]}"')


+ 4
- 2
api/apps/sdk/chat.py Просмотреть файл

@@ -41,7 +41,8 @@ def create(tenant_id):
if kb.chunk_num == 0:
return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file")
kbs = KnowledgebaseService.get_by_ids(ids)
embd_count = list(set([kb.embd_id for kb in kbs]))
embd_ids = [TenantLLMService.split_model_name_and_factory(kb.embd_id)[0] for kb in kbs] # remove vendor suffix for comparison
embd_count = list(set(embd_ids))
if len(embd_count) != 1:
return get_result(message='Datasets use different embedding models."',
code=settings.RetCode.AUTHENTICATION_ERROR)
@@ -176,7 +177,8 @@ def update(tenant_id, chat_id):
if kb.chunk_num == 0:
return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file")
kbs = KnowledgebaseService.get_by_ids(ids)
embd_count = list(set([kb.embd_id for kb in kbs]))
embd_ids = [TenantLLMService.split_model_name_and_factory(kb.embd_id)[0] for kb in kbs] # remove vendor suffix for comparison
embd_count = list(set(embd_ids))
if len(embd_count) != 1:
return get_result(
message='Datasets use different embedding models."',

+ 1
- 1
api/apps/sdk/doc.py Просмотреть файл

@@ -1305,7 +1305,7 @@ def retrieval_test(tenant_id):
if not KnowledgebaseService.accessible(kb_id=id, user_id=tenant_id):
return get_error_data_result(f"You don't own the dataset {id}.")
kbs = KnowledgebaseService.get_by_ids(kb_ids)
embd_nms = list(set([kb.embd_id for kb in kbs]))
embd_nms = list(set([TenantLLMService.split_model_name_and_factory(kb.embd_id)[0] for kb in kbs])) # remove vendor suffix for comparison
if len(embd_nms) != 1:
return get_result(
message='Datasets use different embedding models."',

Загрузка…
Отмена
Сохранить