Explorar el Código

Fix issue of `ask` API. (#5400)

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
tags/v0.17.0
Kevin Hu hace 8 meses
padre
commit
fa76974e24
No account linked to committer's email address
Se han modificado 3 ficheros con 21 adiciones y 15 borrados
  1. 3
    13
      api/db/services/conversation_service.py
  2. 2
    2
      api/db/services/dialog_service.py
  3. 16
    0
      rag/prompts.py

+ 3
- 13
api/db/services/conversation_service.py Ver fichero

@@ -23,6 +23,8 @@ from api.db.services.dialog_service import DialogService, chat
from api.utils import get_uuid
import json

from rag.prompts import chunks_format


class ConversationService(CommonService):
model = Conversation
@@ -53,19 +55,7 @@ def structure_answer(conv, ans, message_id, session_id):
reference = {}
ans["reference"] = {}

def get_value(d, k1, k2):
return d.get(k1, d.get(k2))

chunk_list = [{
"id": get_value(chunk, "chunk_id", "id"),
"content": get_value(chunk, "content", "content_with_weight"),
"document_id": get_value(chunk, "doc_id", "document_id"),
"document_name": get_value(chunk, "docnm_kwd", "document_name"),
"dataset_id": get_value(chunk, "kb_id", "dataset_id"),
"image_id": get_value(chunk, "image_id", "img_id"),
"positions": get_value(chunk, "positions", "position_int"),
"url": chunk.get("url")
} for chunk in reference.get("chunks", [])]
chunk_list = chunks_format(reference)

reference["chunks"] = chunk_list
ans["id"] = message_id

+ 2
- 2
api/db/services/dialog_service.py Ver fichero

@@ -30,7 +30,7 @@ from api import settings
from rag.app.resume import forbidden_select_fields4resume
from rag.app.tag import label_question
from rag.nlp.search import index_name
from rag.prompts import kb_prompt, message_fit_in, llm_id2llm_type, keyword_extraction, full_question
from rag.prompts import kb_prompt, message_fit_in, llm_id2llm_type, keyword_extraction, full_question, chunks_format
from rag.utils import rmSpace, num_tokens_from_string
from rag.utils.tavily_conn import Tavily

@@ -511,7 +511,7 @@ def ask(question, kb_ids, tenant_id):

if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
return {"answer": answer, "reference": refs}
return {"answer": answer, "reference": chunks_format(refs)}

answer = ""
for ans in chat_mdl.chat_streamly(prompt, msg, {"temperature": 0.1}):

+ 16
- 0
rag/prompts.py Ver fichero

@@ -28,6 +28,22 @@ from rag.settings import TAG_FLD
from rag.utils import num_tokens_from_string, encoder


def chunks_format(reference):
def get_value(d, k1, k2):
return d.get(k1, d.get(k2))

return [{
"id": get_value(chunk, "chunk_id", "id"),
"content": get_value(chunk, "content", "content_with_weight"),
"document_id": get_value(chunk, "doc_id", "document_id"),
"document_name": get_value(chunk, "docnm_kwd", "document_name"),
"dataset_id": get_value(chunk, "kb_id", "dataset_id"),
"image_id": get_value(chunk, "image_id", "img_id"),
"positions": get_value(chunk, "positions", "position_int"),
"url": chunk.get("url")
} for chunk in reference.get("chunks", [])]


def llm_id2llm_type(llm_id):
llm_id, _ = TenantLLMService.split_model_name_and_factory(llm_id)
fnm = os.path.join(get_project_base_directory(), "conf")

Cargando…
Cancelar
Guardar