
dataset_retrieval.py

import json
import math
import re
import threading
from collections import Counter, defaultdict
from collections.abc import Generator, Mapping
from typing import Any, Optional, Union, cast

from flask import Flask, current_app
from sqlalchemy import Float, and_, or_, select, text
from sqlalchemy import cast as sqlalchemy_cast
from sqlalchemy.orm import Session

from core.app.app_config.entities import (
    DatasetEntity,
    DatasetRetrieveConfigEntity,
    MetadataFilteringCondition,
    ModelConfig,
)
from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.agent_entities import PlanningStrategy
from core.entities.model_entities import ModelStatus
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.ops.utils import measure_time
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.prompt.simple_prompt_transform import ModelMode
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.entities.context_entities import DocumentContext
from core.rag.entities.metadata_entities import Condition, MetadataCondition
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.models.document import Document
from core.rag.rerank.rerank_type import RerankMode
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter
from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter
from core.rag.retrieval.template_prompts import (
    METADATA_FILTER_ASSISTANT_PROMPT_1,
    METADATA_FILTER_ASSISTANT_PROMPT_2,
    METADATA_FILTER_COMPLETION_PROMPT,
    METADATA_FILTER_SYSTEM_PROMPT,
    METADATA_FILTER_USER_PROMPT_1,
    METADATA_FILTER_USER_PROMPT_2,
    METADATA_FILTER_USER_PROMPT_3,
)
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from extensions.ext_database import db
from libs.json_in_md_parser import parse_and_check_json_markdown
from models.dataset import ChildChunk, Dataset, DatasetMetadata, DatasetQuery, DocumentSegment
from models.dataset import Document as DatasetDocument
from services.external_knowledge_service import ExternalDatasetService

default_retrieval_model: dict[str, Any] = {
    "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
    "reranking_enable": False,
    "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
    "top_k": 4,
    "score_threshold_enabled": False,
}
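
# Note: keys such as "score_threshold", "reranking_mode", and "weights" are read
# elsewhere in this module with .get() and defaults, so they need not appear in
# this fallback configuration.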


class DatasetRetrieval:
    def __init__(self, application_generate_entity=None):
        self.application_generate_entity = application_generate_entity

    def retrieve(
        self,
        app_id: str,
        user_id: str,
        tenant_id: str,
        model_config: ModelConfigWithCredentialsEntity,
        config: DatasetEntity,
        query: str,
        invoke_from: InvokeFrom,
        show_retrieve_source: bool,
        hit_callback: DatasetIndexToolCallbackHandler,
        message_id: str,
        memory: Optional[TokenBufferMemory] = None,
        inputs: Optional[Mapping[str, Any]] = None,
    ) -> Optional[str]:
        """
        Retrieve dataset.
        :param app_id: app id
        :param user_id: user id
        :param tenant_id: tenant id
        :param model_config: model config
        :param config: dataset config
        :param query: query
        :param invoke_from: invoke from
        :param show_retrieve_source: show retrieve source
        :param hit_callback: hit callback
        :param message_id: message id
        :param memory: memory
        :param inputs: inputs
        :return:
        """
        dataset_ids = config.dataset_ids
        if len(dataset_ids) == 0:
            return None
        retrieve_config = config.retrieve_config

        # check whether the model supports tool calling
        model_type_instance = model_config.provider_model_bundle.model_type_instance
        model_type_instance = cast(LargeLanguageModel, model_type_instance)
        model_manager = ModelManager()
        model_instance = model_manager.get_model_instance(
            tenant_id=tenant_id, model_type=ModelType.LLM, provider=model_config.provider, model=model_config.model
        )
        # get model schema
        model_schema = model_type_instance.get_model_schema(
            model=model_config.model, credentials=model_config.credentials
        )
        if not model_schema:
            return None

        planning_strategy = PlanningStrategy.REACT_ROUTER
        features = model_schema.features
        if features:
            if ModelFeature.TOOL_CALL in features or ModelFeature.MULTI_TOOL_CALL in features:
                planning_strategy = PlanningStrategy.ROUTER

        available_datasets = []
        for dataset_id in dataset_ids:
            # get dataset from dataset id
            dataset_stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
            dataset = db.session.scalar(dataset_stmt)
            # skip if the dataset does not exist
            if not dataset:
                continue
            # skip if the dataset has no available documents (external datasets are exempt)
            if dataset.available_document_count == 0 and dataset.provider != "external":
                continue
            available_datasets.append(dataset)

        if inputs:
            inputs = {key: str(value) for key, value in inputs.items()}
        else:
            inputs = {}
        available_datasets_ids = [dataset.id for dataset in available_datasets]
        metadata_filter_document_ids, metadata_condition = self.get_metadata_filter_condition(
            available_datasets_ids,
            query,
            tenant_id,
            user_id,
            retrieve_config.metadata_filtering_mode,  # type: ignore
            retrieve_config.metadata_model_config,  # type: ignore
            retrieve_config.metadata_filtering_conditions,
            inputs,
        )

        all_documents = []
        user_from = "account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user"
        if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
            all_documents = self.single_retrieve(
                app_id,
                tenant_id,
                user_id,
                user_from,
                available_datasets,
                query,
                model_instance,
                model_config,
                planning_strategy,
                message_id,
                metadata_filter_document_ids,
                metadata_condition,
            )
        elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
            all_documents = self.multiple_retrieve(
                app_id,
                tenant_id,
                user_id,
                user_from,
                available_datasets,
                query,
                retrieve_config.top_k or 0,
                retrieve_config.score_threshold or 0,
                retrieve_config.rerank_mode or "reranking_model",
                retrieve_config.reranking_model,
                retrieve_config.weights,
                True if retrieve_config.reranking_enabled is None else retrieve_config.reranking_enabled,
                message_id,
                metadata_filter_document_ids,
                metadata_condition,
            )

        dify_documents = [item for item in all_documents if item.provider == "dify"]
        external_documents = [item for item in all_documents if item.provider == "external"]
        document_context_list: list[DocumentContext] = []
        retrieval_resource_list: list[RetrievalSourceMetadata] = []
        # deal with external documents
        for item in external_documents:
            document_context_list.append(DocumentContext(content=item.page_content, score=item.metadata.get("score")))
            source = RetrievalSourceMetadata(
                dataset_id=item.metadata.get("dataset_id"),
                dataset_name=item.metadata.get("dataset_name"),
                document_id=item.metadata.get("document_id") or item.metadata.get("title"),
                document_name=item.metadata.get("title"),
                data_source_type="external",
                retriever_from=invoke_from.to_source(),
                score=item.metadata.get("score"),
                content=item.page_content,
            )
            retrieval_resource_list.append(source)
        # deal with dify documents
        if dify_documents:
            records = RetrievalService.format_retrieval_documents(dify_documents)
            if records:
                for record in records:
                    segment = record.segment
                    if segment.answer:
                        document_context_list.append(
                            DocumentContext(
                                content=f"question:{segment.get_sign_content()} answer:{segment.answer}",
                                score=record.score,
                            )
                        )
                    else:
                        document_context_list.append(
                            DocumentContext(
                                content=segment.get_sign_content(),
                                score=record.score,
                            )
                        )
                if show_retrieve_source:
                    for record in records:
                        segment = record.segment
                        dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
                        dataset_document_stmt = select(DatasetDocument).where(
                            DatasetDocument.id == segment.document_id,
                            DatasetDocument.enabled == True,
                            DatasetDocument.archived == False,
                        )
                        document = db.session.scalar(dataset_document_stmt)
                        if dataset and document:
                            source = RetrievalSourceMetadata(
                                dataset_id=dataset.id,
                                dataset_name=dataset.name,
                                document_id=document.id,
                                document_name=document.name,
                                data_source_type=document.data_source_type,
                                segment_id=segment.id,
                                retriever_from=invoke_from.to_source(),
                                score=record.score or 0.0,
                                doc_metadata=document.doc_metadata,
                            )
                            if invoke_from.to_source() == "dev":
                                source.hit_count = segment.hit_count
                                source.word_count = segment.word_count
                                source.segment_position = segment.position
                                source.index_node_hash = segment.index_node_hash
                            if segment.answer:
                                source.content = f"question:{segment.content} \nanswer:{segment.answer}"
                            else:
                                source.content = segment.content
                            retrieval_resource_list.append(source)
        if hit_callback and retrieval_resource_list:
            retrieval_resource_list = sorted(retrieval_resource_list, key=lambda x: x.score or 0.0, reverse=True)
            for position, item in enumerate(retrieval_resource_list, start=1):
                item.position = position
            hit_callback.return_retriever_resource_info(retrieval_resource_list)
        if document_context_list:
            document_context_list = sorted(document_context_list, key=lambda x: x.score or 0.0, reverse=True)
            return "\n".join([document_context.content for document_context in document_context_list])
        return ""

    def single_retrieve(
        self,
        app_id: str,
        tenant_id: str,
        user_id: str,
        user_from: str,
        available_datasets: list,
        query: str,
        model_instance: ModelInstance,
        model_config: ModelConfigWithCredentialsEntity,
        planning_strategy: PlanningStrategy,
        message_id: Optional[str] = None,
        metadata_filter_document_ids: Optional[dict[str, list[str]]] = None,
        metadata_condition: Optional[MetadataCondition] = None,
    ):
        tools = []
        for dataset in available_datasets:
            description = dataset.description
            if not description:
                description = "useful for when you want to answer queries about the " + dataset.name
            description = description.replace("\n", "").replace("\r", "")
            message_tool = PromptMessageTool(
                name=dataset.id,
                description=description,
                parameters={
                    "type": "object",
                    "properties": {},
                    "required": [],
                },
            )
            tools.append(message_tool)
        dataset_id = None
        if planning_strategy == PlanningStrategy.REACT_ROUTER:
            react_multi_dataset_router = ReactMultiDatasetRouter()
            dataset_id = react_multi_dataset_router.invoke(
                query, tools, model_config, model_instance, user_id, tenant_id
            )
        elif planning_strategy == PlanningStrategy.ROUTER:
            function_call_router = FunctionCallMultiDatasetRouter()
            dataset_id = function_call_router.invoke(query, tools, model_config, model_instance)

        if dataset_id:
            # get retrieval model config
            dataset_stmt = select(Dataset).where(Dataset.id == dataset_id)
            dataset = db.session.scalar(dataset_stmt)
            if dataset:
                results = []
                timer = None  # set by measure_time() below; stays None for external datasets
                if dataset.provider == "external":
                    external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
                        tenant_id=dataset.tenant_id,
                        dataset_id=dataset_id,
                        query=query,
                        external_retrieval_parameters=dataset.retrieval_model,
                        metadata_condition=metadata_condition,
                    )
                    for external_document in external_documents:
                        document = Document(
                            page_content=external_document.get("content"),
                            metadata=external_document.get("metadata"),
                            provider="external",
                        )
                        if document.metadata is not None:
                            document.metadata["score"] = external_document.get("score")
                            document.metadata["title"] = external_document.get("title")
                            document.metadata["dataset_id"] = dataset_id
                            document.metadata["dataset_name"] = dataset.name
                        results.append(document)
                else:
                    if metadata_condition and not metadata_filter_document_ids:
                        return []
                    document_ids_filter = None
                    if metadata_filter_document_ids:
                        document_ids = metadata_filter_document_ids.get(dataset.id, [])
                        if document_ids:
                            document_ids_filter = document_ids
                        else:
                            return []
                    retrieval_model_config = dataset.retrieval_model or default_retrieval_model

                    # get top k
                    top_k = retrieval_model_config["top_k"]
                    # get retrieval method
                    if dataset.indexing_technique == "economy":
                        retrieval_method = "keyword_search"
                    else:
                        retrieval_method = retrieval_model_config["search_method"]
                    # get reranking model
                    reranking_model = (
                        retrieval_model_config["reranking_model"]
                        if retrieval_model_config["reranking_enable"]
                        else None
                    )
                    # get score threshold
                    score_threshold = 0.0
                    score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
                    if score_threshold_enabled:
                        score_threshold = retrieval_model_config.get("score_threshold", 0.0)

                    with measure_time() as timer:
                        results = RetrievalService.retrieve(
                            retrieval_method=retrieval_method,
                            dataset_id=dataset.id,
                            query=query,
                            top_k=top_k,
                            score_threshold=score_threshold,
                            reranking_model=reranking_model,
                            reranking_mode=retrieval_model_config.get("reranking_mode", "reranking_model"),
                            weights=retrieval_model_config.get("weights", None),
                            document_ids_filter=document_ids_filter,
                        )
                self._on_query(query, [dataset_id], app_id, user_from, user_id)
                if results:
                    self._on_retrieval_end(results, message_id, timer)
                return results
        return []

    def multiple_retrieve(
        self,
        app_id: str,
        tenant_id: str,
        user_id: str,
        user_from: str,
        available_datasets: list,
        query: str,
        top_k: int,
        score_threshold: float,
        reranking_mode: str,
        reranking_model: Optional[dict] = None,
        weights: Optional[dict[str, Any]] = None,
        reranking_enable: bool = True,
        message_id: Optional[str] = None,
        metadata_filter_document_ids: Optional[dict[str, list[str]]] = None,
        metadata_condition: Optional[MetadataCondition] = None,
    ):
        if not available_datasets:
            return []
        threads = []
        all_documents: list[Document] = []
        dataset_ids = [dataset.id for dataset in available_datasets]
        index_type_check = all(
            item.indexing_technique == available_datasets[0].indexing_technique for item in available_datasets
        )
        if not index_type_check and (not reranking_enable or reranking_mode != RerankMode.RERANKING_MODEL):
            raise ValueError(
                "The configured knowledge bases use different indexing techniques; please set a reranking model."
            )
        index_type = available_datasets[0].indexing_technique
        if index_type == "high_quality":
            embedding_model_check = all(
                item.embedding_model == available_datasets[0].embedding_model for item in available_datasets
            )
            embedding_model_provider_check = all(
                item.embedding_model_provider == available_datasets[0].embedding_model_provider
                for item in available_datasets
            )
            if (
                reranking_enable
                and reranking_mode == "weighted_score"
                and (not embedding_model_check or not embedding_model_provider_check)
            ):
                raise ValueError(
                    "The configured knowledge bases use different embedding models; please set a reranking model."
                )
            if reranking_enable and reranking_mode == RerankMode.WEIGHTED_SCORE:
                if weights is not None:
                    weights["vector_setting"]["embedding_provider_name"] = available_datasets[
                        0
                    ].embedding_model_provider
                    weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model
        for dataset in available_datasets:
            index_type = dataset.indexing_technique
            document_ids_filter = None
            if dataset.provider != "external":
                if metadata_condition and not metadata_filter_document_ids:
                    continue
                if metadata_filter_document_ids:
                    document_ids = metadata_filter_document_ids.get(dataset.id, [])
                    if document_ids:
                        document_ids_filter = document_ids
                    else:
                        continue
            retrieval_thread = threading.Thread(
                target=self._retriever,
                kwargs={
                    "flask_app": current_app._get_current_object(),  # type: ignore
                    "dataset_id": dataset.id,
                    "query": query,
                    "top_k": top_k,
                    "all_documents": all_documents,
                    "document_ids_filter": document_ids_filter,
                    "metadata_condition": metadata_condition,
                },
            )
            threads.append(retrieval_thread)
            retrieval_thread.start()
        for thread in threads:
            thread.join()

        with measure_time() as timer:
            if reranking_enable:
                # rerank the retrieved documents
                data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False)
                all_documents = data_post_processor.invoke(
                    query=query, documents=all_documents, score_threshold=score_threshold, top_n=top_k
                )
            else:
                if index_type == "economy":
                    all_documents = self.calculate_keyword_score(query, all_documents, top_k)
                elif index_type == "high_quality":
                    all_documents = self.calculate_vector_score(all_documents, top_k, score_threshold)
                else:
                    all_documents = all_documents[:top_k] if top_k else all_documents

        self._on_query(query, dataset_ids, app_id, user_from, user_id)
        if all_documents:
            self._on_retrieval_end(all_documents, message_id, timer)
        return all_documents
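
    # Design note: the per-dataset retrievals above fan out into one thread per
    # dataset, each appending into the shared `all_documents` list. list.append
    # and list.extend are atomic under CPython's GIL, which this code relies on
    # instead of an explicit lock; result order across datasets is not
    # guaranteed, which is acceptable because documents are reranked or
    # re-scored before being returned.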

    def _on_retrieval_end(
        self, documents: list[Document], message_id: Optional[str] = None, timer: Optional[dict] = None
    ) -> None:
        """Handle retrieval end."""
        dify_documents = [document for document in documents if document.provider == "dify"]
        for document in dify_documents:
            if document.metadata is not None:
                dataset_document_stmt = select(DatasetDocument).where(
                    DatasetDocument.id == document.metadata["document_id"]
                )
                dataset_document = db.session.scalar(dataset_document_stmt)
                if dataset_document:
                    if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
                        child_chunk_stmt = select(ChildChunk).where(
                            ChildChunk.index_node_id == document.metadata["doc_id"],
                            ChildChunk.dataset_id == dataset_document.dataset_id,
                            ChildChunk.document_id == dataset_document.id,
                        )
                        child_chunk = db.session.scalar(child_chunk_stmt)
                        if child_chunk:
                            # add hit count to the parent segment of the child chunk
                            db.session.query(DocumentSegment).where(
                                DocumentSegment.id == child_chunk.segment_id
                            ).update(
                                {DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
                                synchronize_session=False,
                            )
                            db.session.commit()
                    else:
                        query = db.session.query(DocumentSegment).where(
                            DocumentSegment.index_node_id == document.metadata["doc_id"]
                        )
                        if "dataset_id" in document.metadata:
                            query = query.where(DocumentSegment.dataset_id == document.metadata["dataset_id"])
                        # add hit count to document segment
                        query.update(
                            {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False
                        )
                        db.session.commit()
        # get tracing instance
        trace_manager: TraceQueueManager | None = (
            self.application_generate_entity.trace_manager if self.application_generate_entity else None
        )
        if trace_manager:
            trace_manager.add_trace_task(
                TraceTask(
                    TraceTaskName.DATASET_RETRIEVAL_TRACE, message_id=message_id, documents=documents, timer=timer
                )
            )

    def _on_query(self, query: str, dataset_ids: list[str], app_id: str, user_from: str, user_id: str) -> None:
        """Handle query."""
        if not query:
            return
        dataset_queries = []
        for dataset_id in dataset_ids:
            dataset_query = DatasetQuery(
                dataset_id=dataset_id,
                content=query,
                source="app",
                source_app_id=app_id,
                created_by_role=user_from,
                created_by=user_id,
            )
            dataset_queries.append(dataset_query)
        if dataset_queries:
            db.session.add_all(dataset_queries)
        db.session.commit()

    def _retriever(
        self,
        flask_app: Flask,
        dataset_id: str,
        query: str,
        top_k: int,
        all_documents: list,
        document_ids_filter: Optional[list[str]] = None,
        metadata_condition: Optional[MetadataCondition] = None,
    ):
        with flask_app.app_context():
            with Session(db.engine) as session:
                dataset_stmt = select(Dataset).where(Dataset.id == dataset_id)
                # use the session opened above rather than the request-scoped db.session
                dataset = session.scalar(dataset_stmt)
                if not dataset:
                    return []

                if dataset.provider == "external":
                    external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
                        tenant_id=dataset.tenant_id,
                        dataset_id=dataset_id,
                        query=query,
                        external_retrieval_parameters=dataset.retrieval_model,
                        metadata_condition=metadata_condition,
                    )
                    for external_document in external_documents:
                        document = Document(
                            page_content=external_document.get("content"),
                            metadata=external_document.get("metadata"),
                            provider="external",
                        )
                        if document.metadata is not None:
                            document.metadata["score"] = external_document.get("score")
                            document.metadata["title"] = external_document.get("title")
                            document.metadata["dataset_id"] = dataset_id
                            document.metadata["dataset_name"] = dataset.name
                        all_documents.append(document)
                else:
                    # get retrieval model; if it is not set, use the default
                    retrieval_model = dataset.retrieval_model or default_retrieval_model
                    if dataset.indexing_technique == "economy":
                        # use keyword table query
                        documents = RetrievalService.retrieve(
                            retrieval_method="keyword_search",
                            dataset_id=dataset.id,
                            query=query,
                            top_k=top_k,
                            document_ids_filter=document_ids_filter,
                        )
                        if documents:
                            all_documents.extend(documents)
                    else:
                        if top_k > 0:
                            # retrieval source
                            documents = RetrievalService.retrieve(
                                retrieval_method=retrieval_model["search_method"],
                                dataset_id=dataset.id,
                                query=query,
                                top_k=retrieval_model.get("top_k") or 4,
                                score_threshold=retrieval_model.get("score_threshold", 0.0)
                                if retrieval_model["score_threshold_enabled"]
                                else 0.0,
                                reranking_model=retrieval_model.get("reranking_model", None)
                                if retrieval_model["reranking_enable"]
                                else None,
                                reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
                                weights=retrieval_model.get("weights", None),
                                document_ids_filter=document_ids_filter,
                            )
                            all_documents.extend(documents)
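
    # Note: each worker thread pushes its own Flask application context above
    # because database access through the `db` extension is context-bound;
    # without flask_app.app_context(), a thread spawned by multiple_retrieve
    # would fail with a "working outside of application context" error.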

    def to_dataset_retriever_tool(
        self,
        tenant_id: str,
        dataset_ids: list[str],
        retrieve_config: DatasetRetrieveConfigEntity,
        return_resource: bool,
        invoke_from: InvokeFrom,
        hit_callback: DatasetIndexToolCallbackHandler,
        user_id: str,
        inputs: dict,
    ) -> Optional[list[DatasetRetrieverBaseTool]]:
        """
        A dataset tool is a tool that can be used to retrieve information from a dataset.
        :param tenant_id: tenant id
        :param dataset_ids: dataset ids
        :param retrieve_config: retrieve config
        :param return_resource: return resource
        :param invoke_from: invoke from
        :param hit_callback: hit callback
        """
        tools = []
        available_datasets = []
        for dataset_id in dataset_ids:
            # get dataset from dataset id
            dataset_stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
            dataset = db.session.scalar(dataset_stmt)
            # skip if the dataset does not exist
            if not dataset:
                continue
            # skip if the dataset has no available documents (external datasets are exempt)
            if dataset.provider != "external" and dataset.available_document_count == 0:
                continue
            available_datasets.append(dataset)

        if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
            # get retrieval model config
            default_retrieval_model = {
                "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
                "reranking_enable": False,
                "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
                "top_k": 2,
                "score_threshold_enabled": False,
            }
            for dataset in available_datasets:
                retrieval_model_config = dataset.retrieval_model or default_retrieval_model

                # get top k
                top_k = retrieval_model_config["top_k"]

                # get score threshold
                score_threshold = None
                score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
                if score_threshold_enabled:
                    score_threshold = retrieval_model_config.get("score_threshold")

                from core.tools.utils.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool

                tool = DatasetRetrieverTool.from_dataset(
                    dataset=dataset,
                    top_k=top_k,
                    score_threshold=score_threshold,
                    hit_callbacks=[hit_callback],
                    return_resource=return_resource,
                    retriever_from=invoke_from.to_source(),
                    retrieve_config=retrieve_config,
                    user_id=user_id,
                    inputs=inputs,
                )
                tools.append(tool)
        elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
            from core.tools.utils.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool

            if retrieve_config.reranking_model is None:
                raise ValueError("A reranking model is required for multiple retrieval.")

            tool = DatasetMultiRetrieverTool.from_dataset(
                dataset_ids=[dataset.id for dataset in available_datasets],
                tenant_id=tenant_id,
                top_k=retrieve_config.top_k or 4,
                score_threshold=retrieve_config.score_threshold,
                hit_callbacks=[hit_callback],
                return_resource=return_resource,
                retriever_from=invoke_from.to_source(),
                reranking_provider_name=retrieve_config.reranking_model.get("reranking_provider_name"),
                reranking_model_name=retrieve_config.reranking_model.get("reranking_model_name"),
            )
            tools.append(tool)

        return tools
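
    # N.B. the SINGLE branch above defines a local default_retrieval_model with
    # top_k=2, shadowing the module-level default (top_k=4); the SINGLE strategy
    # yields one tool per dataset, whereas MULTIPLE yields a single tool that
    # queries all datasets together.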

    def calculate_keyword_score(self, query: str, documents: list[Document], top_k: int) -> list[Document]:
        """
        Calculate keyword scores.
        :param query: search query
        :param documents: documents for reranking
        :param top_k: top k
        :return:
        """
        keyword_table_handler = JiebaKeywordTableHandler()
        query_keywords = keyword_table_handler.extract_keywords(query, None)
        documents_keywords = []
        for document in documents:
            if document.metadata is not None:
                # get the document keywords
                document_keywords = keyword_table_handler.extract_keywords(document.page_content, None)
                document.metadata["keywords"] = document_keywords
                documents_keywords.append(document_keywords)

        # count query keywords (term frequency)
        query_keyword_counts = Counter(query_keywords)
        # total number of documents
        total_documents = len(documents)
        # calculate the IDF of all documents' keywords
        all_keywords = set()
        for document_keywords in documents_keywords:
            all_keywords.update(document_keywords)
        keyword_idf = {}
        for keyword in all_keywords:
            # count the documents containing this keyword
            doc_count_containing_keyword = sum(1 for doc_keywords in documents_keywords if keyword in doc_keywords)
            # smoothed IDF
            keyword_idf[keyword] = math.log((1 + total_documents) / (1 + doc_count_containing_keyword)) + 1

        query_tfidf = {}
        for keyword, count in query_keyword_counts.items():
            tf = count
            idf = keyword_idf.get(keyword, 0)
            query_tfidf[keyword] = tf * idf

        # calculate all documents' TF-IDF
        documents_tfidf = []
        for document_keywords in documents_keywords:
            document_keyword_counts = Counter(document_keywords)
            document_tfidf = {}
            for keyword, count in document_keyword_counts.items():
                tf = count
                idf = keyword_idf.get(keyword, 0)
                document_tfidf[keyword] = tf * idf
            documents_tfidf.append(document_tfidf)

        def cosine_similarity(vec1, vec2):
            intersection = set(vec1.keys()) & set(vec2.keys())
            numerator = sum(vec1[x] * vec2[x] for x in intersection)
            sum1 = sum(vec1[x] ** 2 for x in vec1)
            sum2 = sum(vec2[x] ** 2 for x in vec2)
            denominator = math.sqrt(sum1) * math.sqrt(sum2)
            if not denominator:
                return 0.0
            else:
                return float(numerator) / denominator

        similarities = []
        for document_tfidf in documents_tfidf:
            similarity = cosine_similarity(query_tfidf, document_tfidf)
            similarities.append(similarity)

        for document, score in zip(documents, similarities):
            # attach the score to the document
            if document.metadata is not None:
                document.metadata["score"] = score
        documents = sorted(documents, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True)
        return documents[:top_k] if top_k else documents
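
    # Worked example (illustrative, not executed): with query keywords
    # ["ai", "ai"] and two documents whose extracted keywords are ["ai", "ml"]
    # and ["db"], every keyword occurs in exactly one of the two documents, so
    # idf = ln(3/2) + 1 ≈ 1.41 for all of them. Then
    #   query_tfidf = {"ai": 2 * 1.41} ≈ {"ai": 2.81}
    #   doc1_tfidf  = {"ai": 1.41, "ml": 1.41}  -> cosine ≈ 1/sqrt(2) ≈ 0.71
    #   doc2_tfidf  = {"db": 1.41}              -> cosine = 0.0 (no overlap)
    # so the first document is ranked ahead of the second.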

    def calculate_vector_score(
        self, all_documents: list[Document], top_k: int, score_threshold: float
    ) -> list[Document]:
        filter_documents = []
        for document in all_documents:
            if score_threshold is None or (document.metadata and document.metadata.get("score", 0) >= score_threshold):
                filter_documents.append(document)

        if not filter_documents:
            return []
        filter_documents = sorted(
            filter_documents, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True
        )
        return filter_documents[:top_k] if top_k else filter_documents
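
    # Example (illustrative): documents scored [0.9, 0.4, 0.7] with
    # score_threshold=0.5 and top_k=2 are filtered to the 0.9 and 0.7 documents
    # and returned in descending score order.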

    def get_metadata_filter_condition(
        self,
        dataset_ids: list,
        query: str,
        tenant_id: str,
        user_id: str,
        metadata_filtering_mode: str,
        metadata_model_config: ModelConfig,
        metadata_filtering_conditions: Optional[MetadataFilteringCondition],
        inputs: dict,
    ) -> tuple[Optional[dict[str, list[str]]], Optional[MetadataCondition]]:
        document_query = db.session.query(DatasetDocument).where(
            DatasetDocument.dataset_id.in_(dataset_ids),
            DatasetDocument.indexing_status == "completed",
            DatasetDocument.enabled == True,
            DatasetDocument.archived == False,
        )
        filters = []  # type: ignore
        metadata_condition = None
        if metadata_filtering_mode == "disabled":
            return None, None
        elif metadata_filtering_mode == "automatic":
            automatic_metadata_filters = self._automatic_metadata_filter_func(
                dataset_ids, query, tenant_id, user_id, metadata_model_config
            )
            if automatic_metadata_filters:
                conditions = []
                for sequence, filter in enumerate(automatic_metadata_filters):
                    self._process_metadata_filter_func(
                        sequence,
                        filter.get("condition"),  # type: ignore
                        filter.get("metadata_name"),  # type: ignore
                        filter.get("value"),
                        filters,  # type: ignore
                    )
                    conditions.append(
                        Condition(
                            name=filter.get("metadata_name"),  # type: ignore
                            comparison_operator=filter.get("condition"),  # type: ignore
                            value=filter.get("value"),
                        )
                    )
                metadata_condition = MetadataCondition(
                    logical_operator=metadata_filtering_conditions.logical_operator
                    if metadata_filtering_conditions
                    else "or",  # type: ignore
                    conditions=conditions,
                )
        elif metadata_filtering_mode == "manual":
            if metadata_filtering_conditions:
                conditions = []
                for sequence, condition in enumerate(metadata_filtering_conditions.conditions):  # type: ignore
                    metadata_name = condition.name
                    expected_value = condition.value
                    if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"):
                        if isinstance(expected_value, str):
                            expected_value = self._replace_metadata_filter_value(expected_value, inputs)
                    conditions.append(
                        Condition(
                            name=metadata_name,
                            comparison_operator=condition.comparison_operator,
                            value=expected_value,
                        )
                    )
                    filters = self._process_metadata_filter_func(
                        sequence,
                        condition.comparison_operator,
                        metadata_name,
                        expected_value,
                        filters,
                    )
                metadata_condition = MetadataCondition(
                    logical_operator=metadata_filtering_conditions.logical_operator,
                    conditions=conditions,
                )
        else:
            raise ValueError("Invalid metadata filtering mode")
        if filters:
            if metadata_filtering_conditions and metadata_filtering_conditions.logical_operator == "and":  # type: ignore
                document_query = document_query.where(and_(*filters))
            else:
                document_query = document_query.where(or_(*filters))
        documents = document_query.all()
        # group by dataset_id
        metadata_filter_document_ids = defaultdict(list) if documents else None  # type: ignore
        for document in documents:
            metadata_filter_document_ids[document.dataset_id].append(document.id)  # type: ignore
        return metadata_filter_document_ids, metadata_condition
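
    # The three filtering modes above behave as follows: "disabled" applies no
    # metadata filter; "automatic" asks an LLM to infer filter conditions from
    # the query; "manual" evaluates user-configured conditions, substituting
    # {{variable}} placeholders from `inputs`. Both non-disabled modes reduce to
    # a per-dataset map of matching document ids plus a MetadataCondition that
    # external knowledge providers can consume.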

    def _replace_metadata_filter_value(self, text: str, inputs: dict) -> str:
        if not inputs:
            return text

        def replacer(match):
            key = match.group(1)
            return str(inputs.get(key, f"{{{{{key}}}}}"))

        pattern = re.compile(r"\{\{(\w+)\}\}")
        output = pattern.sub(replacer, text)
        if isinstance(output, str):
            output = re.sub(r"[\r\n\t]+", " ", output).strip()
        return output
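
    # Example (illustrative):
    #   _replace_metadata_filter_value("category is {{category}}", {"category": "news"})
    #   -> "category is news"
    # Unknown keys are left in place as "{{key}}", and stray \r, \n, and \t
    # characters are collapsed to single spaces.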

    def _automatic_metadata_filter_func(
        self, dataset_ids: list, query: str, tenant_id: str, user_id: str, metadata_model_config: ModelConfig
    ) -> Optional[list[dict[str, Any]]]:
        # get all metadata fields
        metadata_stmt = select(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids))
        metadata_fields = db.session.scalars(metadata_stmt).all()
        all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields]
        # get metadata model config
        if metadata_model_config is None:
            raise ValueError("metadata_model_config is required")
        # fetch model instance and config
        model_instance, model_config = self._fetch_model_config(tenant_id, metadata_model_config)
        # fetch prompt messages
        prompt_messages, stop = self._get_prompt_template(
            model_config=model_config,
            mode=metadata_model_config.mode,
            metadata_fields=all_metadata_fields,
            query=query or "",
        )
        result_text = ""
        try:
            # invoke the model
            invoke_result = cast(
                Generator[LLMResult, None, None],
                model_instance.invoke_llm(
                    prompt_messages=prompt_messages,
                    model_parameters=model_config.parameters,
                    stop=stop,
                    stream=True,
                    user=user_id,
                ),
            )
            # handle invoke result
            result_text, usage = self._handle_invoke_result(invoke_result=invoke_result)
            result_text_json = parse_and_check_json_markdown(result_text, [])
            automatic_metadata_filters = []
            if "metadata_map" in result_text_json:
                metadata_map = result_text_json["metadata_map"]
                for item in metadata_map:
                    if item.get("metadata_field_name") in all_metadata_fields:
                        automatic_metadata_filters.append(
                            {
                                "metadata_name": item.get("metadata_field_name"),
                                "value": item.get("metadata_field_value"),
                                "condition": item.get("comparison_operator"),
                            }
                        )
        except Exception:
            return None
        return automatic_metadata_filters
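
    # The model is expected to reply with JSON of the shape
    #   {"metadata_map": [{"metadata_field_name": ..., "metadata_field_value": ...,
    #                      "comparison_operator": ...}]}
    # Entries naming unknown fields are dropped, and any parse or invocation
    # failure makes this function return None (i.e. no automatic filtering).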

    def _process_metadata_filter_func(
        self, sequence: int, condition: str, metadata_name: str, value: Optional[Any], filters: list
    ):
        if value is None and condition not in ("empty", "not empty"):
            return

        key = f"{metadata_name}_{sequence}"
        key_value = f"{metadata_name}_{sequence}_value"
        match condition:
            case "contains":
                filters.append(
                    (text(f"documents.doc_metadata ->> :{key} LIKE :{key_value}")).params(
                        **{key: metadata_name, key_value: f"%{value}%"}
                    )
                )
            case "not contains":
                filters.append(
                    (text(f"documents.doc_metadata ->> :{key} NOT LIKE :{key_value}")).params(
                        **{key: metadata_name, key_value: f"%{value}%"}
                    )
                )
            case "start with":
                filters.append(
                    (text(f"documents.doc_metadata ->> :{key} LIKE :{key_value}")).params(
                        **{key: metadata_name, key_value: f"{value}%"}
                    )
                )
            case "end with":
                filters.append(
                    (text(f"documents.doc_metadata ->> :{key} LIKE :{key_value}")).params(
                        **{key: metadata_name, key_value: f"%{value}"}
                    )
                )
            case "is" | "=":
                if isinstance(value, str):
                    filters.append(DatasetDocument.doc_metadata[metadata_name] == f'"{value}"')
                else:
                    filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) == value)
            case "is not" | "≠":
                if isinstance(value, str):
                    filters.append(DatasetDocument.doc_metadata[metadata_name] != f'"{value}"')
                else:
                    filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) != value)
            case "empty":
                filters.append(DatasetDocument.doc_metadata[metadata_name].is_(None))
            case "not empty":
                filters.append(DatasetDocument.doc_metadata[metadata_name].isnot(None))
            case "before" | "<":
                filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) < value)
            case "after" | ">":
                filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) > value)
            case "≤" | "<=":
                filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) <= value)
            case "≥" | ">=":
                filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) >= value)
            case _:
                pass
        return filters
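
    # For example (illustrative), a "contains" condition on field "author" with
    # value "smith" at sequence 0 appends roughly this SQL clause, with bound
    # parameters rather than inlined strings:
    #   documents.doc_metadata ->> :author_0 LIKE :author_0_value
    #   (author_0 = 'author', author_0_value = '%smith%')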

    def _fetch_model_config(
        self, tenant_id: str, model: ModelConfig
    ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
        """
        Fetch model config.
        """
        if model is None:
            raise ValueError("single_retrieval_config is required")
        model_name = model.name
        provider_name = model.provider

        model_manager = ModelManager()
        model_instance = model_manager.get_model_instance(
            tenant_id=tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name
        )

        provider_model_bundle = model_instance.provider_model_bundle
        model_type_instance = model_instance.model_type_instance
        model_type_instance = cast(LargeLanguageModel, model_type_instance)

        model_credentials = model_instance.credentials

        # check model
        provider_model = provider_model_bundle.configuration.get_provider_model(
            model=model_name, model_type=ModelType.LLM
        )
        if provider_model is None:
            raise ValueError(f"Model {model_name} does not exist.")
        if provider_model.status == ModelStatus.NO_CONFIGURE:
            raise ValueError(f"Model {model_name} credentials are not initialized.")
        elif provider_model.status == ModelStatus.NO_PERMISSION:
            raise ValueError(f"Dify Hosted OpenAI {model_name} is currently not supported.")
        elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
            raise ValueError(f"Model provider {provider_name} quota exceeded.")

        # model config
        completion_params = model.completion_params
        stop = []
        if "stop" in completion_params:
            stop = completion_params["stop"]
            del completion_params["stop"]

        # get model mode
        model_mode = model.mode
        if not model_mode:
            raise ValueError("LLM mode is required.")

        model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
        if not model_schema:
            raise ValueError(f"Model {model_name} does not exist.")

        return model_instance, ModelConfigWithCredentialsEntity(
            provider=provider_name,
            model=model_name,
            model_schema=model_schema,
            mode=model_mode,
            provider_model_bundle=provider_model_bundle,
            credentials=model_credentials,
            parameters=completion_params,
            stop=stop,
        )

    def _get_prompt_template(
        self, model_config: ModelConfigWithCredentialsEntity, mode: str, metadata_fields: list, query: str
    ):
        model_mode = ModelMode(mode)
        input_text = query

        prompt_template: Union[CompletionModelPromptTemplate, list[ChatModelMessage]]
        if model_mode == ModelMode.CHAT:
            prompt_template = []
            system_prompt_messages = ChatModelMessage(role=PromptMessageRole.SYSTEM, text=METADATA_FILTER_SYSTEM_PROMPT)
            prompt_template.append(system_prompt_messages)
            user_prompt_message_1 = ChatModelMessage(role=PromptMessageRole.USER, text=METADATA_FILTER_USER_PROMPT_1)
            prompt_template.append(user_prompt_message_1)
            assistant_prompt_message_1 = ChatModelMessage(
                role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_1
            )
            prompt_template.append(assistant_prompt_message_1)
            user_prompt_message_2 = ChatModelMessage(role=PromptMessageRole.USER, text=METADATA_FILTER_USER_PROMPT_2)
            prompt_template.append(user_prompt_message_2)
            assistant_prompt_message_2 = ChatModelMessage(
                role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_2
            )
            prompt_template.append(assistant_prompt_message_2)
            user_prompt_message_3 = ChatModelMessage(
                role=PromptMessageRole.USER,
                text=METADATA_FILTER_USER_PROMPT_3.format(
                    input_text=input_text,
                    metadata_fields=json.dumps(metadata_fields, ensure_ascii=False),
                ),
            )
            prompt_template.append(user_prompt_message_3)
        elif model_mode == ModelMode.COMPLETION:
            prompt_template = CompletionModelPromptTemplate(
                text=METADATA_FILTER_COMPLETION_PROMPT.format(
                    input_text=input_text,
                    metadata_fields=json.dumps(metadata_fields, ensure_ascii=False),
                )
            )
        else:
            raise ValueError(f"Model mode {model_mode} is not supported.")

        prompt_transform = AdvancedPromptTransform()
        prompt_messages = prompt_transform.get_prompt(
            prompt_template=prompt_template,
            inputs={},
            query=query or "",
            files=[],
            context=None,
            memory_config=None,
            memory=None,
            model_config=model_config,
        )
        stop = model_config.stop
        return prompt_messages, stop

    def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage]:
        """
        Handle invoke result.
        :param invoke_result: invoke result
        :return:
        """
        model = None
        prompt_messages: list[PromptMessage] = []
        full_text = ""
        usage = None
        for result in invoke_result:
            text = result.delta.message.content
            full_text += text
            if not model:
                model = result.model
            if not prompt_messages:
                prompt_messages = result.prompt_messages
            if not usage and result.delta.usage:
                usage = result.delta.usage
        if not usage:
            usage = LLMUsage.empty_usage()
        return full_text, usage