### What problem does this PR solve? Fixes #7376 #4503 #5710 #7470 ### Type of change - [x] New Feature (non-breaking change which adds functionality) — tags/v0.19.0
| @@ -22,7 +22,7 @@ from flask_login import login_required, current_user | |||
| from rag.app.qa import rmPrefix, beAdoc | |||
| from rag.app.tag import label_question | |||
| from rag.nlp import search, rag_tokenizer | |||
| from rag.prompts import keyword_extraction | |||
| from rag.prompts import keyword_extraction, cross_languages | |||
| from rag.settings import PAGERANK_FLD | |||
| from rag.utils import rmSpace | |||
| from api.db import LLMType, ParserType | |||
| @@ -275,6 +275,7 @@ def retrieval_test(): | |||
| vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3)) | |||
| use_kg = req.get("use_kg", False) | |||
| top = int(req.get("top_k", 1024)) | |||
| langs = req.get("cross_languages", []) | |||
| tenant_ids = [] | |||
| try: | |||
| @@ -294,6 +295,9 @@ def retrieval_test(): | |||
| if not e: | |||
| return get_data_error_result(message="Knowledgebase not found!") | |||
| if langs: | |||
| question = cross_languages(kb.tenant_id, None, question, langs) | |||
| embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id) | |||
| rerank_mdl = None | |||
| @@ -36,7 +36,8 @@ from api.utils import current_timestamp, datetime_format | |||
| from rag.app.resume import forbidden_select_fields4resume | |||
| from rag.app.tag import label_question | |||
| from rag.nlp.search import index_name | |||
| from rag.prompts import chunks_format, citation_prompt, full_question, kb_prompt, keyword_extraction, llm_id2llm_type, message_fit_in | |||
| from rag.prompts import chunks_format, citation_prompt, full_question, kb_prompt, keyword_extraction, llm_id2llm_type, message_fit_in, \ | |||
| cross_languages | |||
| from rag.utils import num_tokens_from_string, rmSpace | |||
| from rag.utils.tavily_conn import Tavily | |||
| @@ -214,6 +215,9 @@ def chat(dialog, messages, stream=True, **kwargs): | |||
| else: | |||
| questions = questions[-1:] | |||
| if prompt_config.get("cross_languages"): | |||
| questions = [cross_languages(dialog.tenant_id, dialog.llm_id, questions[0], prompt_config["cross_languages"])] | |||
| refine_question_ts = timer() | |||
| rerank_mdl = None | |||
| @@ -131,7 +131,7 @@ class DocumentService(CommonService): | |||
| if types: | |||
| query = query.where(cls.model.type.in_(types)) | |||
| return query.scalar() or 0 | |||
| return int(query.scalar()) or 0 | |||
| @classmethod | |||
| @DB.connection_context() | |||
| @@ -306,6 +306,60 @@ Output: What's the weather in Rochester on {tomorrow}? | |||
| ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL) | |||
| return ans if ans.find("**ERROR**") < 0 else messages[-1]["content"] | |||
| def cross_languages(tenant_id, llm_id, query, languages=[]): | |||
| from api.db.services.llm_service import LLMBundle | |||
| if llm_id and llm_id2llm_type(llm_id) == "image2text": | |||
| chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id) | |||
| else: | |||
| chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id) | |||
| sys_prompt = """ | |||
| Act as a streamlined multilingual translator. Strictly output translations separated by ### without any explanations or formatting. Follow these rules: | |||
| 1. Accept batch translation requests in format: | |||
| [source text] | |||
| === | |||
| [target languages separated by commas] | |||
| 2. Always maintain: | |||
| - Original formatting (tables/lists/spacing) | |||
| - Technical terminology accuracy | |||
| - Cultural context appropriateness | |||
| 3. Output format: | |||
| [language1 translation] | |||
| ### | |||
| [language1 translation] | |||
| **Examples:** | |||
| Input: | |||
| Hello World! Let's discuss AI safety. | |||
| === | |||
| Chinese, French, Jappanese | |||
| Output: | |||
| 你好世界!让我们讨论人工智能安全问题。 | |||
| ### | |||
| Bonjour le monde ! Parlons de la sécurité de l'IA. | |||
| ### | |||
| こんにちは世界!AIの安全性について話し合いましょう。 | |||
| """ | |||
| user_prompt=f""" | |||
| Input: | |||
| {query} | |||
| === | |||
| {', '.join(languages)} | |||
| Output: | |||
| """ | |||
| ans = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_prompt}], {"temperature": 0.2}) | |||
| ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL) | |||
| if ans.find("**ERROR**") >= 0: | |||
| return query | |||
| return "\n".join([a for a in re.sub(r"(^Output:|\n+)", "", ans, flags=re.DOTALL).split("===") if a.strip()]) | |||
| def content_tagging(chat_mdl, content, all_tags, examples, topn=3): | |||
| prompt = f""" | |||