ソースを参照

Fix issue of `ask` API. (#5400)

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
tags/v0.17.0
Kevin Hu 8ヶ月前
コミット
fa76974e24
コミッターのメールアドレスに関連付けられたアカウントが存在しません
3個のファイルの変更21行の追加15行の削除
  1. 3
    13
      api/db/services/conversation_service.py
  2. 2
    2
      api/db/services/dialog_service.py
  3. 16
    0
      rag/prompts.py

+ 3
- 13
api/db/services/conversation_service.py ファイルの表示

from api.utils import get_uuid from api.utils import get_uuid
import json import json


from rag.prompts import chunks_format



class ConversationService(CommonService): class ConversationService(CommonService):
model = Conversation model = Conversation
reference = {} reference = {}
ans["reference"] = {} ans["reference"] = {}


def get_value(d, k1, k2):
return d.get(k1, d.get(k2))

chunk_list = [{
"id": get_value(chunk, "chunk_id", "id"),
"content": get_value(chunk, "content", "content_with_weight"),
"document_id": get_value(chunk, "doc_id", "document_id"),
"document_name": get_value(chunk, "docnm_kwd", "document_name"),
"dataset_id": get_value(chunk, "kb_id", "dataset_id"),
"image_id": get_value(chunk, "image_id", "img_id"),
"positions": get_value(chunk, "positions", "position_int"),
"url": chunk.get("url")
} for chunk in reference.get("chunks", [])]
chunk_list = chunks_format(reference)


reference["chunks"] = chunk_list reference["chunks"] = chunk_list
ans["id"] = message_id ans["id"] = message_id

+ 2
- 2
api/db/services/dialog_service.py ファイルの表示

from rag.app.resume import forbidden_select_fields4resume from rag.app.resume import forbidden_select_fields4resume
from rag.app.tag import label_question from rag.app.tag import label_question
from rag.nlp.search import index_name from rag.nlp.search import index_name
from rag.prompts import kb_prompt, message_fit_in, llm_id2llm_type, keyword_extraction, full_question
from rag.prompts import kb_prompt, message_fit_in, llm_id2llm_type, keyword_extraction, full_question, chunks_format
from rag.utils import rmSpace, num_tokens_from_string from rag.utils import rmSpace, num_tokens_from_string
from rag.utils.tavily_conn import Tavily from rag.utils.tavily_conn import Tavily




if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0: if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'" answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
return {"answer": answer, "reference": refs}
return {"answer": answer, "reference": chunks_format(refs)}


answer = "" answer = ""
for ans in chat_mdl.chat_streamly(prompt, msg, {"temperature": 0.1}): for ans in chat_mdl.chat_streamly(prompt, msg, {"temperature": 0.1}):

+ 16
- 0
rag/prompts.py ファイルの表示

from rag.utils import num_tokens_from_string, encoder from rag.utils import num_tokens_from_string, encoder




def chunks_format(reference):
def get_value(d, k1, k2):
return d.get(k1, d.get(k2))

return [{
"id": get_value(chunk, "chunk_id", "id"),
"content": get_value(chunk, "content", "content_with_weight"),
"document_id": get_value(chunk, "doc_id", "document_id"),
"document_name": get_value(chunk, "docnm_kwd", "document_name"),
"dataset_id": get_value(chunk, "kb_id", "dataset_id"),
"image_id": get_value(chunk, "image_id", "img_id"),
"positions": get_value(chunk, "positions", "position_int"),
"url": chunk.get("url")
} for chunk in reference.get("chunks", [])]


def llm_id2llm_type(llm_id): def llm_id2llm_type(llm_id):
llm_id, _ = TenantLLMService.split_model_name_and_factory(llm_id) llm_id, _ = TenantLLMService.split_model_name_and_factory(llm_id)
fnm = os.path.join(get_project_base_directory(), "conf") fnm = os.path.join(get_project_base_directory(), "conf")

読み込み中…
キャンセル
保存