
Feat: support cross-lang search. (#7557)

### What problem does this PR solve?

#7376
#4503
#5710 
#7470

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
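
In short: when a list of target languages is configured, the user question is translated into each of them by the new `cross_languages` helper (see `rag/prompts.py` below) and the translations are joined into one multilingual query before retrieval. A purely illustrative sketch of the query retrieval ends up seeing; the translated strings here are made up:

```python
# Illustrative only: in the PR the translations are produced by the chat model
# via cross_languages(); the values below are made-up placeholders.
question = "How do I configure chunking?"
translations = [
    "如何配置分块?",                         # Chinese (hypothetical output)
    "Comment configurer le découpage ?",      # French (hypothetical output)
]
multilingual_query = "\n".join(translations)  # this is what retrieval runs on
```
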
tags/v0.19.0
Kevin Hu, 5 months ago
Commit 2ccec93d71
4 changed files with 65 additions and 3 deletions
1. api/apps/chunk_app.py (+5, -1)
2. api/db/services/dialog_service.py (+5, -1)
3. api/db/services/document_service.py (+1, -1)
4. rag/prompts.py (+54, -0)

api/apps/chunk_app.py (+5, -1)

@@ -22,7 +22,7 @@ from flask_login import login_required, current_user
 from rag.app.qa import rmPrefix, beAdoc
 from rag.app.tag import label_question
 from rag.nlp import search, rag_tokenizer
-from rag.prompts import keyword_extraction
+from rag.prompts import keyword_extraction, cross_languages
 from rag.settings import PAGERANK_FLD
 from rag.utils import rmSpace
 from api.db import LLMType, ParserType
@@ -275,6 +275,7 @@ def retrieval_test():
     vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
     use_kg = req.get("use_kg", False)
     top = int(req.get("top_k", 1024))
+    langs = req.get("cross_languages", [])
     tenant_ids = []
 
     try:
@@ -294,6 +295,9 @@ def retrieval_test():
         if not e:
             return get_data_error_result(message="Knowledgebase not found!")
 
+        if langs:
+            question = cross_languages(kb.tenant_id, None, question, langs)
+
         embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
 
         rerank_mdl = None
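
For reference, a client might exercise the new parameter on the retrieval-test endpoint like this. This is a sketch: the URL, port, auth header, and the `kb_id` field are assumptions; only `question`, `vector_similarity_weight`, `use_kg`, `top_k`, and the new `cross_languages` field correspond to what the handler above reads.

```python
import requests

payload = {
    "kb_id": "your-kb-id",                      # assumed field name
    "question": "How do I configure chunking?",
    "vector_similarity_weight": 0.3,
    "use_kg": False,
    "top_k": 1024,
    "cross_languages": ["Chinese", "French"],   # new in this PR
}
resp = requests.post(
    "http://localhost:9380/v1/chunk/retrieval_test",    # assumed URL/port
    json=payload,
    headers={"Authorization": "Bearer <your-api-key>"},  # assumed auth scheme
)
print(resp.json())
```
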

api/db/services/dialog_service.py (+5, -1)

@@ -36,7 +36,8 @@ from api.utils import current_timestamp, datetime_format
 from rag.app.resume import forbidden_select_fields4resume
 from rag.app.tag import label_question
 from rag.nlp.search import index_name
-from rag.prompts import chunks_format, citation_prompt, full_question, kb_prompt, keyword_extraction, llm_id2llm_type, message_fit_in
+from rag.prompts import chunks_format, citation_prompt, full_question, kb_prompt, keyword_extraction, llm_id2llm_type, message_fit_in, \
+    cross_languages
 from rag.utils import num_tokens_from_string, rmSpace
 from rag.utils.tavily_conn import Tavily

@@ -214,6 +215,9 @@ def chat(dialog, messages, stream=True, **kwargs):
     else:
         questions = questions[-1:]
 
+    if prompt_config.get("cross_languages"):
+        questions = [cross_languages(dialog.tenant_id, dialog.llm_id, questions[0], prompt_config["cross_languages"])]
+
     refine_question_ts = timer()
 
     rerank_mdl = None
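
On the chat side the feature is driven by the dialog's `prompt_config`: when a `cross_languages` list is present, the latest question is rewritten before retrieval. A minimal sketch of such a config; only the `cross_languages` key is what the new code path reads, the other keys are assumptions about a typical prompt_config.

```python
# Only "cross_languages" is consumed by the new code path above; the other keys
# are assumptions sketching a typical dialog prompt_config.
prompt_config = {
    "system": "You are a helpful assistant. Answer using {knowledge}.",
    "parameters": [{"key": "knowledge", "optional": False}],
    "cross_languages": ["English", "Spanish"],  # translate the question first
}
```
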

api/db/services/document_service.py (+1, -1)

@@ -131,7 +131,7 @@ class DocumentService(CommonService):
         if types:
             query = query.where(cls.model.type.in_(types))
 
-        return query.scalar() or 0
+        return int(query.scalar()) or 0
 
     @classmethod
     @DB.connection_context()

rag/prompts.py (+54, -0)

@@ -306,6 +306,60 @@ Output: What's the weather in Rochester on {tomorrow}?
     ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
     return ans if ans.find("**ERROR**") < 0 else messages[-1]["content"]
 
+def cross_languages(tenant_id, llm_id, query, languages=[]):
+    from api.db.services.llm_service import LLMBundle
+
+    if llm_id and llm_id2llm_type(llm_id) == "image2text":
+        chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
+    else:
+        chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
+
+    sys_prompt = """
+Act as a streamlined multilingual translator. Strictly output translations separated by ### without any explanations or formatting. Follow these rules:
+
+1. Accept batch translation requests in format:
+[source text]
+===
+[target languages separated by commas]
+
+2. Always maintain:
+- Original formatting (tables/lists/spacing)
+- Technical terminology accuracy
+- Cultural context appropriateness
+
+3. Output format:
+[language1 translation]
+###
+[language2 translation]
+
+**Examples:**
+Input:
+Hello World! Let's discuss AI safety.
+===
+Chinese, French, Japanese
+
+Output:
+你好世界!让我们讨论人工智能安全问题。
+###
+Bonjour le monde ! Parlons de la sécurité de l'IA.
+###
+こんにちは世界!AIの安全性について話し合いましょう。
+"""
+    user_prompt = f"""
+Input:
+{query}
+===
+{', '.join(languages)}
+
+Output:
+"""
+
+    ans = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_prompt}], {"temperature": 0.2})
+    ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
+    if ans.find("**ERROR**") >= 0:
+        return query
+    return "\n".join([a for a in re.sub(r"(^Output:|\n+)", "", ans, flags=re.DOTALL).split("###") if a.strip()])
+
 
 def content_tagging(chat_mdl, content, all_tags, examples, topn=3):
     prompt = f"""
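
Finally, a direct usage sketch of the new helper itself. The tenant id is a placeholder; passing `llm_id=None` selects the tenant's default chat model, as in the retrieval-test path above, and the comment only indicates the shape of the newline-joined result (made-up translations).

```python
from rag.prompts import cross_languages

# "tenant-123" is a placeholder; llm_id=None falls back to the default chat model.
multilingual = cross_languages("tenant-123", None,
                               "What is the refund policy?",
                               ["Chinese", "German"])
print(multilingual)
# Expected shape (illustrative translations):
# 退款政策是什么?
# Wie lautet die Rückerstattungsrichtlinie?
#
# If the model call fails ("**ERROR**" in the reply), the original query is
# returned unchanged.
```
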
