from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from models.model import File

from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import DatasourceEntity, DatasourceProviderType
from core.tools.entities.tool_entities import (
    ToolInvokeMessage,
    ToolParameter,
)


class Datasource(ABC):
    """
    The base class of a datasource.
    """

    entity: DatasourceEntity
    runtime: DatasourceRuntime

    def __init__(self, entity: DatasourceEntity, runtime: DatasourceRuntime) -> None:
        self.entity = entity
        self.runtime = runtime

    def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "Datasource":
        """
        Fork a new datasource with the given runtime.

        :return: the new datasource
        """
        return self.__class__(
            entity=self.entity.model_copy(),
            runtime=runtime,
        )

    @abstractmethod
    def datasource_provider_type(self) -> DatasourceProviderType:
        """
        Get the datasource provider type.

        :return: the datasource provider type
        """
from typing import Any, Optional

from pydantic import BaseModel, Field

from core.app.entities.app_invoke_entities import InvokeFrom
from core.datasource.entities.datasource_entities import DatasourceInvokeFrom
from core.tools.entities.tool_entities import ToolInvokeFrom


class DatasourceRuntime(BaseModel):
    """
    Metadata of a datasource invocation.
    """

    tenant_id: str
    # remaining fields mirror what FakeDatasourceRuntime below supplies
    datasource_id: Optional[str] = None
    invoke_from: Optional[InvokeFrom] = None
    datasource_invoke_from: Optional[DatasourceInvokeFrom] = None
    credentials: dict[str, Any] = Field(default_factory=dict)
    runtime_parameters: dict[str, Any] = Field(default_factory=dict)


class FakeDatasourceRuntime(DatasourceRuntime):
    """
    Fake datasource runtime for testing.
    """

    def __init__(self):
        super().__init__(
            tenant_id="fake_tenant_id",
            datasource_id="fake_datasource_id",
            invoke_from=InvokeFrom.DEBUGGER,
            datasource_invoke_from=DatasourceInvokeFrom.RAG_PIPELINE,
            credentials={},
            runtime_parameters={},
        )
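
# A minimal sketch of how the fake runtime might be used in a test; every
# field is pinned to a deterministic value by the constructor above:
def test_fake_runtime_defaults():
    runtime = FakeDatasourceRuntime()
    assert runtime.tenant_id == "fake_tenant_id"
    assert runtime.datasource_invoke_from == DatasourceInvokeFrom.RAG_PIPELINE
    assert runtime.runtime_parameters == {}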
from models.model import Message, MessageFile


class DatasourceEngine:
    """
    Datasource runtime engine that takes care of datasource executions.
    """

    @staticmethod
from collections.abc import Generator
from typing import Any, Optional

from core.datasource.__base.datasource import Datasource
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import DatasourceEntity, DatasourceParameter, DatasourceProviderType
from core.plugin.manager.datasource import PluginDatasourceManager
from core.plugin.manager.tool import PluginToolManager
from core.plugin.utils.converter import convert_parameters_to_plugin_format
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolParameter, ToolProviderType


class DatasourcePlugin(Datasource):
    tenant_id: str
    icon: str
    plugin_unique_identifier: str
    # cached runtime parameters, populated on first fetch
    runtime_parameters: Optional[list[DatasourceParameter]] = None

    def _invoke_first_step(
        self,
        user_id: str,
        datasource_parameters: dict[str, Any],
        rag_pipeline_id: Optional[str] = None,
    ) -> Generator[ToolInvokeMessage, None, None]:
        manager = PluginDatasourceManager()

        datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters)

        yield from manager.invoke_first_step(
            tenant_id=self.tenant_id,
            user_id=user_id,
            datasource_provider=self.entity.identity.provider,
            datasource_name=self.entity.identity.name,
            credentials=self.runtime.credentials,
            datasource_parameters=datasource_parameters,
            rag_pipeline_id=rag_pipeline_id,
        )

    def _invoke_second_step(
        self,
        user_id: str,
        datasource_parameters: dict[str, Any],
        rag_pipeline_id: Optional[str] = None,
    ) -> Generator[ToolInvokeMessage, None, None]:
        manager = PluginDatasourceManager()

        datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters)

        yield from manager.invoke(
            tenant_id=self.tenant_id,
            user_id=user_id,
            datasource_provider=self.entity.identity.provider,
            datasource_name=self.entity.identity.name,
            credentials=self.runtime.credentials,
            datasource_parameters=datasource_parameters,
            rag_pipeline_id=rag_pipeline_id,
        )

    def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin":
        return DatasourcePlugin(
            entity=self.entity,
            runtime=runtime,
            tenant_id=self.tenant_id,
            icon=self.icon,
            plugin_unique_identifier=self.plugin_unique_identifier,
        )

    def get_runtime_parameters(
        self,
        rag_pipeline_id: Optional[str] = None,
    ) -> list[DatasourceParameter]:
        """
        get the runtime parameters
        """
        if self.runtime_parameters is not None:
            return self.runtime_parameters

        manager = PluginDatasourceManager()
        self.runtime_parameters = manager.get_runtime_parameters(
            tenant_id=self.tenant_id,
            user_id="",
            provider=self.entity.identity.provider,
            datasource=self.entity.identity.name,
            credentials=self.runtime.credentials,
            rag_pipeline_id=rag_pipeline_id,
        )
        return self.runtime_parameters
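
# A minimal sketch of driving the two-step plugin invocation; the parameter
# values are illustrative and depend on the plugin's declared parameters:
def _example_invoke_first_step(plugin: DatasourcePlugin) -> None:
    for message in plugin._invoke_first_step(
        user_id="user-1",
        datasource_parameters={"path": "/docs"},  # illustrative parameters
        rag_pipeline_id="pipeline-1",
    ):
        print(message)  # each chunk is a ToolInvokeMessage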
import threading
from typing import Any

from flask import Flask, current_app
from pydantic import BaseModel, Field

from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.models.document import Document as RagDocument
from core.rag.rerank.rerank_model import RerankModelRunner
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment

default_retrieval_model: dict[str, Any] = {
    "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
    "reranking_enable": False,
    "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
    "top_k": 2,
    "score_threshold_enabled": False,
}


class DatasetMultiRetrieverToolInput(BaseModel):
    query: str = Field(..., description="dataset multi retriever and rerank")


class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
    """Tool for querying multiple datasets."""

    name: str = "dataset_"
    args_schema: type[BaseModel] = DatasetMultiRetrieverToolInput
    description: str = "dataset multi retriever and rerank. "
    dataset_ids: list[str]
    reranking_provider_name: str
    reranking_model_name: str

    @classmethod
    def from_dataset(cls, dataset_ids: list[str], tenant_id: str, **kwargs):
        return cls(
            name=f"dataset_{tenant_id.replace('-', '_')}", tenant_id=tenant_id, dataset_ids=dataset_ids, **kwargs
        )

    def _run(self, query: str) -> str:
        threads = []
        all_documents: list[RagDocument] = []
        for dataset_id in self.dataset_ids:
            retrieval_thread = threading.Thread(
                target=self._retriever,
                kwargs={
                    "flask_app": current_app._get_current_object(),  # type: ignore
                    "dataset_id": dataset_id,
                    "query": query,
                    "all_documents": all_documents,
                    "hit_callbacks": self.hit_callbacks,
                },
            )
            threads.append(retrieval_thread)
            retrieval_thread.start()
        for thread in threads:
            thread.join()

        # rerank the retrieved documents
        model_manager = ModelManager()
        rerank_model_instance = model_manager.get_model_instance(
            tenant_id=self.tenant_id,
            provider=self.reranking_provider_name,
            model_type=ModelType.RERANK,
            model=self.reranking_model_name,
        )

        rerank_runner = RerankModelRunner(rerank_model_instance)
        all_documents = rerank_runner.run(query, all_documents, self.score_threshold, self.top_k)

        for hit_callback in self.hit_callbacks:
            hit_callback.on_tool_end(all_documents)

        document_score_list = {}
        for item in all_documents:
            if item.metadata and item.metadata.get("score"):
                document_score_list[item.metadata["doc_id"]] = item.metadata["score"]

        document_context_list = []
        index_node_ids = [document.metadata["doc_id"] for document in all_documents if document.metadata]
        segments = DocumentSegment.query.filter(
            DocumentSegment.dataset_id.in_(self.dataset_ids),
            DocumentSegment.completed_at.isnot(None),
            DocumentSegment.status == "completed",
            DocumentSegment.enabled == True,
            DocumentSegment.index_node_id.in_(index_node_ids),
        ).all()

        if segments:
            index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
            sorted_segments = sorted(
                segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf"))
            )
            for segment in sorted_segments:
                if segment.answer:
                    document_context_list.append(f"question:{segment.get_sign_content()} answer:{segment.answer}")
                else:
                    document_context_list.append(segment.get_sign_content())

            if self.return_resource:
                context_list = []
                resource_number = 1
                for segment in sorted_segments:
                    dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
                    document = Document.query.filter(
                        Document.id == segment.document_id,
                        Document.enabled == True,
                        Document.archived == False,
                    ).first()
                    if dataset and document:
                        source = {
                            "position": resource_number,
                            "dataset_id": dataset.id,
                            "dataset_name": dataset.name,
                            "document_id": document.id,
                            "document_name": document.name,
                            "data_source_type": document.data_source_type,
                            "segment_id": segment.id,
                            "retriever_from": self.retriever_from,
                            "score": document_score_list.get(segment.index_node_id, None),
                            "doc_metadata": document.doc_metadata,
                        }

                        if self.retriever_from == "dev":
                            source["hit_count"] = segment.hit_count
                            source["word_count"] = segment.word_count
                            source["segment_position"] = segment.position
                            source["index_node_hash"] = segment.index_node_hash
                        if segment.answer:
                            source["content"] = f"question:{segment.content} \nanswer:{segment.answer}"
                        else:
                            source["content"] = segment.content
                        context_list.append(source)
                        resource_number += 1

                for hit_callback in self.hit_callbacks:
                    hit_callback.return_retriever_resource_info(context_list)

            return str("\n".join(document_context_list))

        return ""

    def _retriever(
        self,
        flask_app: Flask,
        dataset_id: str,
        query: str,
        all_documents: list,
        hit_callbacks: list[DatasetIndexToolCallbackHandler],
    ):
        with flask_app.app_context():
            dataset = (
                db.session.query(Dataset).filter(Dataset.tenant_id == self.tenant_id, Dataset.id == dataset_id).first()
            )

            if not dataset:
                return []

            for hit_callback in hit_callbacks:
                hit_callback.on_query(query, dataset.id)

            # get the retrieval model; fall back to the default if none is configured
            retrieval_model = dataset.retrieval_model or default_retrieval_model

            if dataset.indexing_technique == "economy":
                # use keyword table query
                documents = RetrievalService.retrieve(
                    retrieval_method="keyword_search",
                    dataset_id=dataset.id,
                    query=query,
                    top_k=retrieval_model.get("top_k") or 2,
                )
                if documents:
                    all_documents.extend(documents)
            else:
                if self.top_k > 0:
                    # retrieval source
                    documents = RetrievalService.retrieve(
                        retrieval_method=retrieval_model["search_method"],
                        dataset_id=dataset.id,
                        query=query,
                        top_k=retrieval_model.get("top_k") or 2,
                        score_threshold=retrieval_model.get("score_threshold", 0.0)
                        if retrieval_model["score_threshold_enabled"]
                        else 0.0,
                        reranking_model=retrieval_model.get("reranking_model", None)
                        if retrieval_model["reranking_enable"]
                        else None,
                        reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
                        weights=retrieval_model.get("weights", None),
                    )

                    all_documents.extend(documents)
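
# A minimal sketch of wiring the multi-dataset retriever; the ids and reranker
# names below are illustrative and must exist in the caller's tenant:
def _example_multi_retrieve(hit_callback: DatasetIndexToolCallbackHandler) -> str:
    tool = DatasetMultiRetrieverTool.from_dataset(
        dataset_ids=["ds-1", "ds-2"],  # illustrative dataset ids
        tenant_id="tenant-1",
        reranking_provider_name="cohere",  # illustrative reranker provider
        reranking_model_name="rerank-english-v2.0",
        return_resource=True,
        retriever_from="dev",
        hit_callbacks=[hit_callback],
    )
    # retrieval runs per dataset in threads, then the results are reranked
    return tool._run(query="What is our refund policy?")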
from abc import ABC, abstractmethod
from typing import Any, Optional

from pydantic import BaseModel, ConfigDict

from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler


class DatasetRetrieverBaseTool(BaseModel, ABC):
    """Tool for querying a Dataset."""

    name: str = "dataset"
    description: str = "use this to retrieve a dataset. "
    tenant_id: str
    top_k: int = 2
    score_threshold: Optional[float] = None
    hit_callbacks: list[DatasetIndexToolCallbackHandler] = []
    return_resource: bool
    retriever_from: str
    model_config = ConfigDict(arbitrary_types_allowed=True)

    @abstractmethod
    def _run(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> Any:
        """Use the tool.

        Add run_manager: Optional[CallbackManagerForToolRun] = None
        to child implementations to enable tracing.
        """
from typing import Any

from pydantic import BaseModel, Field

from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.context_entities import DocumentContext
from core.rag.models.document import Document as RetrievalDocument
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from extensions.ext_database import db
from models.dataset import Dataset
from models.dataset import Document as DatasetDocument
from services.external_knowledge_service import ExternalDatasetService

default_retrieval_model = {
    "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
    "reranking_enable": False,
    "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
    "reranking_mode": "reranking_model",
    "top_k": 2,
    "score_threshold_enabled": False,
}


class DatasetRetrieverToolInput(BaseModel):
    query: str = Field(..., description="Query for the dataset to be used to retrieve the dataset.")


class DatasetRetrieverTool(DatasetRetrieverBaseTool):
    """Tool for querying a Dataset."""

    name: str = "dataset"
    args_schema: type[BaseModel] = DatasetRetrieverToolInput
    description: str = "use this to retrieve a dataset. "
    dataset_id: str

    @classmethod
    def from_dataset(cls, dataset: Dataset, **kwargs):
        description = dataset.description
        if not description:
            description = "useful for when you want to answer queries about the " + dataset.name

        description = description.replace("\n", "").replace("\r", "")
        return cls(
            name=f"dataset_{dataset.id.replace('-', '_')}",
            tenant_id=dataset.tenant_id,
            dataset_id=dataset.id,
            description=description,
            **kwargs,
        )

    def _run(self, query: str) -> str:
        dataset = (
            db.session.query(Dataset).filter(Dataset.tenant_id == self.tenant_id, Dataset.id == self.dataset_id).first()
        )

        if not dataset:
            return ""

        for hit_callback in self.hit_callbacks:
            hit_callback.on_query(query, dataset.id)

        if dataset.provider == "external":
            results = []
            external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
                tenant_id=dataset.tenant_id,
                dataset_id=dataset.id,
                query=query,
                external_retrieval_parameters=dataset.retrieval_model,
            )
            for external_document in external_documents:
                document = RetrievalDocument(
                    page_content=external_document.get("content"),
                    metadata=external_document.get("metadata"),
                    provider="external",
                )
                if document.metadata is not None:
                    document.metadata["score"] = external_document.get("score")
                    document.metadata["title"] = external_document.get("title")
                    document.metadata["dataset_id"] = dataset.id
                    document.metadata["dataset_name"] = dataset.name
                results.append(document)

            # build citation sources for the external documents
            context_list = []
            for position, item in enumerate(results, start=1):
                if item.metadata is not None:
                    source = {
                        "position": position,
                        "dataset_id": item.metadata.get("dataset_id"),
                        "dataset_name": item.metadata.get("dataset_name"),
                        "document_name": item.metadata.get("title"),
                        "data_source_type": "external",
                        "retriever_from": self.retriever_from,
                        "score": item.metadata.get("score"),
                        "title": item.metadata.get("title"),
                        "content": item.page_content,
                    }
                    context_list.append(source)
            for hit_callback in self.hit_callbacks:
                hit_callback.return_retriever_resource_info(context_list)

            return str("\n".join([item.page_content for item in results]))
        else:
            # get the retrieval model; fall back to the default if none is configured
            retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model
            if dataset.indexing_technique == "economy":
                # use keyword table query
                documents = RetrievalService.retrieve(
                    retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=self.top_k
                )
                return str("\n".join([document.page_content for document in documents]))
            else:
                if self.top_k > 0:
                    # retrieval source
                    documents = RetrievalService.retrieve(
                        retrieval_method=retrieval_model.get("search_method", "semantic_search"),
                        dataset_id=dataset.id,
                        query=query,
                        top_k=self.top_k,
                        score_threshold=retrieval_model.get("score_threshold", 0.0)
                        if retrieval_model["score_threshold_enabled"]
                        else 0.0,
                        reranking_model=retrieval_model.get("reranking_model")
                        if retrieval_model["reranking_enable"]
                        else None,
                        reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
                        weights=retrieval_model.get("weights"),
                    )
                else:
                    documents = []

                for hit_callback in self.hit_callbacks:
                    hit_callback.on_tool_end(documents)

                document_score_list = {}
                if dataset.indexing_technique != "economy":
                    for item in documents:
                        if item.metadata is not None and item.metadata.get("score"):
                            document_score_list[item.metadata["doc_id"]] = item.metadata["score"]

                document_context_list = []
                records = RetrievalService.format_retrieval_documents(documents)
                if records:
                    for record in records:
                        segment = record.segment
                        if segment.answer:
                            document_context_list.append(
                                DocumentContext(
                                    content=f"question:{segment.get_sign_content()} answer:{segment.answer}",
                                    score=record.score,
                                )
                            )
                        else:
                            document_context_list.append(
                                DocumentContext(
                                    content=segment.get_sign_content(),
                                    score=record.score,
                                )
                            )

                    retrieval_resource_list = []
                    if self.return_resource:
                        for record in records:
                            segment = record.segment
                            dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
                            document = DatasetDocument.query.filter(
                                DatasetDocument.id == segment.document_id,
                                DatasetDocument.enabled == True,
                                DatasetDocument.archived == False,
                            ).first()
                            if dataset and document:
                                source = {
                                    "dataset_id": dataset.id,
                                    "dataset_name": dataset.name,
                                    "document_id": document.id,  # type: ignore
                                    "document_name": document.name,  # type: ignore
                                    "data_source_type": document.data_source_type,  # type: ignore
                                    "segment_id": segment.id,
                                    "retriever_from": self.retriever_from,
                                    "score": record.score or 0.0,
                                    "doc_metadata": document.doc_metadata,  # type: ignore
                                }

                                if self.retriever_from == "dev":
                                    source["hit_count"] = segment.hit_count
                                    source["word_count"] = segment.word_count
                                    source["segment_position"] = segment.position
                                    source["index_node_hash"] = segment.index_node_hash
                                if segment.answer:
                                    source["content"] = f"question:{segment.content} \nanswer:{segment.answer}"
                                else:
                                    source["content"] = segment.content
                                retrieval_resource_list.append(source)

                    if self.return_resource and retrieval_resource_list:
                        retrieval_resource_list = sorted(
                            retrieval_resource_list,
                            key=lambda x: x.get("score") or 0.0,
                            reverse=True,
                        )
                        for position, item in enumerate(retrieval_resource_list, start=1):  # type: ignore
                            item["position"] = position  # type: ignore
                        for hit_callback in self.hit_callbacks:
                            hit_callback.return_retriever_resource_info(retrieval_resource_list)

                if document_context_list:
                    document_context_list = sorted(document_context_list, key=lambda x: x.score or 0.0, reverse=True)
                    return str("\n".join([document_context.content for document_context in document_context_list]))

                return ""
from collections.abc import Generator
from typing import Any, Optional

from core.app.app_config.entities import DatasetRetrieveConfigEntity
from core.app.entities.app_invoke_entities import InvokeFrom
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import (
    ToolDescription,
    ToolEntity,
    ToolIdentity,
    ToolInvokeMessage,
    ToolParameter,
    ToolProviderType,
)
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool


class DatasetRetrieverTool(Tool):
    retrieval_tool: DatasetRetrieverBaseTool

    def __init__(self, entity: ToolEntity, runtime: ToolRuntime, retrieval_tool: DatasetRetrieverBaseTool) -> None:
        super().__init__(entity, runtime)
        self.retrieval_tool = retrieval_tool

    @staticmethod
    def get_dataset_tools(
        tenant_id: str,
        dataset_ids: list[str],
        retrieve_config: DatasetRetrieveConfigEntity | None,
        return_resource: bool,
        invoke_from: InvokeFrom,
        hit_callback: DatasetIndexToolCallbackHandler,
    ) -> list["DatasetRetrieverTool"]:
        """
        get dataset tools
        """
        # check if retrieve_config is valid
        if dataset_ids is None or len(dataset_ids) == 0:
            return []
        if retrieve_config is None:
            return []

        feature = DatasetRetrieval()

        # save the original retrieve strategy and temporarily switch to SINGLE,
        # since the agent only supports SINGLE mode
        original_retriever_mode = retrieve_config.retrieve_strategy
        retrieve_config.retrieve_strategy = DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE
        retrieval_tools = feature.to_dataset_retriever_tool(
            tenant_id=tenant_id,
            dataset_ids=dataset_ids,
            retrieve_config=retrieve_config,
            return_resource=return_resource,
            invoke_from=invoke_from,
            hit_callback=hit_callback,
        )
        if retrieval_tools is None or len(retrieval_tools) == 0:
            return []

        # restore the retrieve strategy
        retrieve_config.retrieve_strategy = original_retriever_mode

        # convert the retrieval tools to Tools
        tools = []
        for retrieval_tool in retrieval_tools:
            tool = DatasetRetrieverTool(
                retrieval_tool=retrieval_tool,
                entity=ToolEntity(
                    identity=ToolIdentity(
                        provider="", author="", name=retrieval_tool.name, label=I18nObject(en_US="", zh_Hans="")
                    ),
                    parameters=[],
                    description=ToolDescription(human=I18nObject(en_US="", zh_Hans=""), llm=retrieval_tool.description),
                ),
                runtime=ToolRuntime(tenant_id=tenant_id),
            )
            tools.append(tool)

        return tools

    def get_runtime_parameters(
        self,
        conversation_id: Optional[str] = None,
        app_id: Optional[str] = None,
        message_id: Optional[str] = None,
    ) -> list[ToolParameter]:
        return [
            ToolParameter(
                name="query",
                label=I18nObject(en_US="", zh_Hans=""),
                human_description=I18nObject(en_US="", zh_Hans=""),
                type=ToolParameter.ToolParameterType.STRING,
                form=ToolParameter.ToolParameterForm.LLM,
                llm_description="Query for the dataset to be used to retrieve the dataset.",
                required=True,
                default="",
                placeholder=I18nObject(en_US="", zh_Hans=""),
            ),
        ]

    def tool_provider_type(self) -> ToolProviderType:
        return ToolProviderType.DATASET_RETRIEVAL

    def _invoke(
        self,
        user_id: str,
        tool_parameters: dict[str, Any],
        conversation_id: Optional[str] = None,
        app_id: Optional[str] = None,
        message_id: Optional[str] = None,
    ) -> Generator[ToolInvokeMessage, None, None]:
        """
        invoke the dataset retriever tool
        """
        query = tool_parameters.get("query")
        if not query:
            yield self.create_text_message(text="please input query")
        else:
            # delegate to the wrapped retrieval tool
            result = self.retrieval_tool._run(query=query)
            yield self.create_text_message(text=result)

    def validate_credentials(
        self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False
    ) -> str | None:
        """
        validate the credentials for the dataset retriever tool
        """
        pass
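
# A minimal sketch of obtaining agent-ready dataset tools, assuming the config
# and callback objects are constructed elsewhere by the caller:
def _example_dataset_tools(
    retrieve_config: DatasetRetrieveConfigEntity,
    hit_callback: DatasetIndexToolCallbackHandler,
) -> list[DatasetRetrieverTool]:
    return DatasetRetrieverTool.get_dataset_tools(
        tenant_id="tenant-1",
        dataset_ids=["ds-1"],  # illustrative dataset id
        retrieve_config=retrieve_config,
        return_resource=True,
        invoke_from=InvokeFrom.DEBUGGER,
        hit_callback=hit_callback,
    )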
| """ | |||||
| For some reason, model will be used in tools like WebScraperTool, WikipediaSearchTool etc. | |||||
| Therefore, a model manager is needed to list/invoke/validate models. | |||||
| """ | |||||
| import json | |||||
| from typing import Optional, cast | |||||
| from core.model_manager import ModelManager | |||||
| from core.model_runtime.entities.llm_entities import LLMResult | |||||
| from core.model_runtime.entities.message_entities import PromptMessage | |||||
| from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType | |||||
| from core.model_runtime.errors.invoke import ( | |||||
| InvokeAuthorizationError, | |||||
| InvokeBadRequestError, | |||||
| InvokeConnectionError, | |||||
| InvokeRateLimitError, | |||||
| InvokeServerUnavailableError, | |||||
| ) | |||||
| from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel | |||||
| from core.model_runtime.utils.encoders import jsonable_encoder | |||||
| from extensions.ext_database import db | |||||
| from models.tools import ToolModelInvoke | |||||
| class InvokeModelError(Exception): | |||||
| pass | |||||
| class ModelInvocationUtils: | |||||
| @staticmethod | |||||
| def get_max_llm_context_tokens( | |||||
| tenant_id: str, | |||||
| ) -> int: | |||||
| """ | |||||
| get max llm context tokens of the model | |||||
| """ | |||||
| model_manager = ModelManager() | |||||
| model_instance = model_manager.get_default_model_instance( | |||||
| tenant_id=tenant_id, | |||||
| model_type=ModelType.LLM, | |||||
| ) | |||||
| if not model_instance: | |||||
| raise InvokeModelError("Model not found") | |||||
| llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) | |||||
| schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials) | |||||
| if not schema: | |||||
| raise InvokeModelError("No model schema found") | |||||
| max_tokens: Optional[int] = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE, None) | |||||
| if max_tokens is None: | |||||
| return 2048 | |||||
| return max_tokens | |||||
| @staticmethod | |||||
| def calculate_tokens(tenant_id: str, prompt_messages: list[PromptMessage]) -> int: | |||||
| """ | |||||
| calculate tokens from prompt messages and model parameters | |||||
| """ | |||||
| # get model instance | |||||
| model_manager = ModelManager() | |||||
| model_instance = model_manager.get_default_model_instance(tenant_id=tenant_id, model_type=ModelType.LLM) | |||||
| if not model_instance: | |||||
| raise InvokeModelError("Model not found") | |||||
| # get tokens | |||||
| tokens = model_instance.get_llm_num_tokens(prompt_messages) | |||||
| return tokens | |||||
| @staticmethod | |||||
| def invoke( | |||||
| user_id: str, tenant_id: str, tool_type: str, tool_name: str, prompt_messages: list[PromptMessage] | |||||
| ) -> LLMResult: | |||||
| """ | |||||
| invoke model with parameters in user's own context | |||||
| :param user_id: user id | |||||
| :param tenant_id: tenant id, the tenant id of the creator of the tool | |||||
| :param tool_type: tool type | |||||
| :param tool_name: tool name | |||||
| :param prompt_messages: prompt messages | |||||
| :return: AssistantPromptMessage | |||||
| """ | |||||
| # get model manager | |||||
| model_manager = ModelManager() | |||||
| # get model instance | |||||
| model_instance = model_manager.get_default_model_instance( | |||||
| tenant_id=tenant_id, | |||||
| model_type=ModelType.LLM, | |||||
| ) | |||||
| # get prompt tokens | |||||
| prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages) | |||||
| model_parameters = { | |||||
| "temperature": 0.8, | |||||
| "top_p": 0.8, | |||||
| } | |||||
| # create tool model invoke | |||||
| tool_model_invoke = ToolModelInvoke( | |||||
| user_id=user_id, | |||||
| tenant_id=tenant_id, | |||||
| provider=model_instance.provider, | |||||
| tool_type=tool_type, | |||||
| tool_name=tool_name, | |||||
| model_parameters=json.dumps(model_parameters), | |||||
| prompt_messages=json.dumps(jsonable_encoder(prompt_messages)), | |||||
| model_response="", | |||||
| prompt_tokens=prompt_tokens, | |||||
| answer_tokens=0, | |||||
| answer_unit_price=0, | |||||
| answer_price_unit=0, | |||||
| provider_response_latency=0, | |||||
| total_price=0, | |||||
| currency="USD", | |||||
| ) | |||||
| db.session.add(tool_model_invoke) | |||||
| db.session.commit() | |||||
| try: | |||||
| response: LLMResult = cast( | |||||
| LLMResult, | |||||
| model_instance.invoke_llm( | |||||
| prompt_messages=prompt_messages, | |||||
| model_parameters=model_parameters, | |||||
| tools=[], | |||||
| stop=[], | |||||
| stream=False, | |||||
| user=user_id, | |||||
| callbacks=[], | |||||
| ), | |||||
| ) | |||||
| except InvokeRateLimitError as e: | |||||
| raise InvokeModelError(f"Invoke rate limit error: {e}") | |||||
| except InvokeBadRequestError as e: | |||||
| raise InvokeModelError(f"Invoke bad request error: {e}") | |||||
| except InvokeConnectionError as e: | |||||
| raise InvokeModelError(f"Invoke connection error: {e}") | |||||
| except InvokeAuthorizationError as e: | |||||
| raise InvokeModelError("Invoke authorization error") | |||||
| except InvokeServerUnavailableError as e: | |||||
| raise InvokeModelError(f"Invoke server unavailable error: {e}") | |||||
| except Exception as e: | |||||
| raise InvokeModelError(f"Invoke error: {e}") | |||||
| # update tool model invoke | |||||
| tool_model_invoke.model_response = response.message.content | |||||
| if response.usage: | |||||
| tool_model_invoke.answer_tokens = response.usage.completion_tokens | |||||
| tool_model_invoke.answer_unit_price = response.usage.completion_unit_price | |||||
| tool_model_invoke.answer_price_unit = response.usage.completion_price_unit | |||||
| tool_model_invoke.provider_response_latency = response.usage.latency | |||||
| tool_model_invoke.total_price = response.usage.total_price | |||||
| tool_model_invoke.currency = response.usage.currency | |||||
| db.session.commit() | |||||
| return response |
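
# A minimal sketch of a tool invoking the default LLM through these utilities;
# UserPromptMessage is the standard user message type in this codebase, and
# the tool_type/tool_name labels are free-form strings recorded on the row:
def _example_summarize(user_id: str, tenant_id: str, text: str) -> str:
    from core.model_runtime.entities.message_entities import UserPromptMessage

    prompt_messages = [UserPromptMessage(content=f"Summarize: {text}")]
    max_tokens = ModelInvocationUtils.get_max_llm_context_tokens(tenant_id)
    if ModelInvocationUtils.calculate_tokens(tenant_id, prompt_messages) > max_tokens:
        raise InvokeModelError("prompt exceeds the model context window")
    response = ModelInvocationUtils.invoke(
        user_id=user_id,
        tenant_id=tenant_id,
        tool_type="datasource",  # illustrative label
        tool_name="summarizer",  # illustrative label
        prompt_messages=prompt_messages,
    )
    return str(response.message.content)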
import re


def get_image_upload_file_ids(content):
    # match markdown image links that point at file-preview/image-preview URLs
    pattern = r"!\[image\]\((https?://.*?(file-preview|image-preview))\)"
    matches = re.findall(pattern, content)
    image_upload_file_ids = []
    for match in matches:
        if match[1] == "file-preview":
            content_pattern = r"files/([^/]+)/file-preview"
        else:
            content_pattern = r"files/([^/]+)/image-preview"
        content_match = re.search(content_pattern, match[0])
        if content_match:
            image_upload_file_id = content_match.group(1)
            image_upload_file_ids.append(image_upload_file_id)
    return image_upload_file_ids
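
if __name__ == "__main__":
    # quick illustration (hypothetical URL): the upload id is extracted from a preview link
    sample = "See ![image](https://example.com/files/abc-123/image-preview)"
    assert get_image_upload_file_ids(sample) == ["abc-123"]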
import hashlib
import json
import mimetypes
import os
import re
import site
import subprocess
import tempfile
import unicodedata
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Literal, Optional, cast
from urllib.parse import unquote

import chardet
import cloudscraper  # type: ignore
from bs4 import BeautifulSoup, CData, Comment, NavigableString  # type: ignore
from regex import regex  # type: ignore

from core.helper import ssrf_proxy
from core.rag.extractor import extract_processor
from core.rag.extractor.extract_processor import ExtractProcessor

FULL_TEMPLATE = """
TITLE: {title}
AUTHORS: {authors}
PUBLISH DATE: {publish_date}
TOP_IMAGE_URL: {top_image}
TEXT:

{text}
"""


def page_result(text: str, cursor: int, max_length: int) -> str:
    """Page through `text` and return a substring of `max_length` characters starting from `cursor`."""
    return text[cursor : cursor + max_length]


def get_url(url: str, user_agent: Optional[str] = None) -> str:
    """Fetch a URL and return its contents as a string."""
    headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)"
        " Chrome/91.0.4472.124 Safari/537.36"
    }
    if user_agent:
        headers["User-Agent"] = user_agent

    main_content_type = None
    supported_content_types = extract_processor.SUPPORT_URL_CONTENT_TYPES + ["text/html"]
    response = ssrf_proxy.head(url, headers=headers, follow_redirects=True, timeout=(5, 10))

    if response.status_code == 200:
        # check the content-type
        content_type = response.headers.get("Content-Type")
        if content_type:
            main_content_type = content_type.split(";")[0].strip()
        else:
            content_disposition = response.headers.get("Content-Disposition", "")
            filename_match = re.search(r'filename="([^"]+)"', content_disposition)
            if filename_match:
                filename = unquote(filename_match.group(1))
                extension = re.search(r"\.(\w+)$", filename)
                if extension:
                    main_content_type = mimetypes.guess_type(filename)[0]

        if main_content_type not in supported_content_types:
            return "Unsupported content-type [{}] of URL.".format(main_content_type)

        if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES:
            return cast(str, ExtractProcessor.load_from_url(url, return_text=True))

        response = ssrf_proxy.get(url, headers=headers, follow_redirects=True, timeout=(120, 300))
    elif response.status_code == 403:
        scraper = cloudscraper.create_scraper()
        scraper.perform_request = ssrf_proxy.make_request
        response = scraper.get(url, headers=headers, follow_redirects=True, timeout=(120, 300))

    if response.status_code != 200:
        return "URL returned status code {}.".format(response.status_code)

    # Detect the encoding using chardet
    detected_encoding = chardet.detect(response.content)
    encoding = detected_encoding["encoding"]
    if encoding:
        try:
            content = response.content.decode(encoding)
        except (UnicodeDecodeError, TypeError):
            content = response.text
    else:
        content = response.text

    a = extract_using_readabilipy(content)

    if not a["plain_text"] or not a["plain_text"].strip():
        return ""

    res = FULL_TEMPLATE.format(
        title=a["title"],
        authors=a["byline"],
        publish_date=a["date"],
        top_image="",
        text=a["plain_text"] or "",
    )

    return res
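
# A minimal sketch of the fetch-and-extract flow; the URL is illustrative and
# the call performs real network and node/readabilipy work:
#
#     article = get_url("https://example.com/blog/post")
#     print(page_result(article, cursor=0, max_length=500))  # first 500 characters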
def extract_using_readabilipy(html):
    # Write the input HTML to a temporary file so it is available to the node executable
    with tempfile.NamedTemporaryFile(delete=False, mode="w+") as f_html:
        f_html.write(html)
        f_html.close()
    html_path = f_html.name

    # Call Mozilla's Readability.js Readability.parse() function via node, writing output to a temporary file
    article_json_path = html_path + ".json"
    jsdir = os.path.join(find_module_path("readabilipy"), "javascript")
    with chdir(jsdir):
        subprocess.check_call(["node", "ExtractArticle.js", "-i", html_path, "-o", article_json_path])

    # Read the output of the call to Readability.parse() from the JSON file and return it as a Python dictionary
    input_json = json.loads(Path(article_json_path).read_text(encoding="utf-8"))

    # Delete the temporary files after processing
    os.unlink(article_json_path)
    os.unlink(html_path)

    article_json: dict[str, Any] = {
        "title": None,
        "byline": None,
        "date": None,
        "content": None,
        "plain_content": None,
        "plain_text": None,
    }
    # Populate the article fields from the readability fields where present
    if input_json:
        if input_json.get("title"):
            article_json["title"] = input_json["title"]
        if input_json.get("byline"):
            article_json["byline"] = input_json["byline"]
        if input_json.get("date"):
            article_json["date"] = input_json["date"]
        if input_json.get("content"):
            article_json["content"] = input_json["content"]
            article_json["plain_content"] = plain_content(article_json["content"], False, False)
            article_json["plain_text"] = extract_text_blocks_as_plain_text(article_json["plain_content"])
        if input_json.get("textContent"):
            article_json["plain_text"] = input_json["textContent"]
            article_json["plain_text"] = re.sub(r"\n\s*\n", "\n", article_json["plain_text"])

    return article_json


def find_module_path(module_name):
    for package_path in site.getsitepackages():
        potential_path = os.path.join(package_path, module_name)
        if os.path.exists(potential_path):
            return potential_path

    return None


@contextmanager
def chdir(path):
    """Change the working directory within the context and return to the original on exit."""
    # From https://stackoverflow.com/a/37996581, couldn't find a built-in
    original_path = os.getcwd()
    os.chdir(path)
    try:
        yield
    finally:
        os.chdir(original_path)
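
# chdir is a plain context manager; for example, running a command from another directory:
#
#     with chdir("/tmp"):
#         subprocess.check_call(["ls"])  # runs with /tmp as the working directory
#     # the original working directory is restored here, even if an error was raised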
def extract_text_blocks_as_plain_text(paragraph_html):
    # Load the article as a DOM
    soup = BeautifulSoup(paragraph_html, "html.parser")
    # Select all lists
    list_elements = soup.find_all(["ul", "ol"])
    # Prefix the text in all list items with "* " and turn the lists into paragraphs
    for list_element in list_elements:
        plain_items = "".join(
            list(filter(None, [plain_text_leaf_node(li)["text"] for li in list_element.find_all("li")]))
        )
        list_element.string = plain_items
        list_element.name = "p"
    # Select all text blocks
    text_blocks = [s.parent for s in soup.find_all(string=True)]
    text_blocks = [plain_text_leaf_node(block) for block in text_blocks]
    # Drop empty paragraphs
    text_blocks = list(filter(lambda p: p["text"] is not None, text_blocks))
    return text_blocks


def plain_text_leaf_node(element):
    # Extract all text, stripped of any child HTML elements, and normalize it
    plain_text = normalize_text(element.get_text())
    if plain_text != "" and element.name == "li":
        plain_text = "* {}, ".format(plain_text)
    if plain_text == "":
        plain_text = None
    if "data-node-index" in element.attrs:
        plain = {"node_index": element["data-node-index"], "text": plain_text}
    else:
        plain = {"text": plain_text}
    return plain


def plain_content(readability_content, content_digests, node_indexes):
    # Load the article as a DOM
    soup = BeautifulSoup(readability_content, "html.parser")
    # Make all elements plain
    elements = plain_elements(soup.contents, content_digests, node_indexes)
    if node_indexes:
        # Add node index attributes to the nodes
        elements = [add_node_indexes(element) for element in elements]
    # Replace the article contents with the plain elements
    soup.contents = elements
    return str(soup)


def plain_elements(elements, content_digests, node_indexes):
    # Get plain content versions of all elements
    elements = [plain_element(element, content_digests, node_indexes) for element in elements]
    if content_digests:
        # Add a content digest attribute to the nodes
        elements = [add_content_digest(element) for element in elements]
    return elements


def plain_element(element, content_digests, node_indexes):
    # For lists, we make each item plain text
    if is_leaf(element):
        # For leaf node elements, extract the text content, discarding any HTML tags
        # 1. Get the element contents as text
        plain_text = element.get_text()
        # 2. Normalize the extracted text string to a canonical representation
        plain_text = normalize_text(plain_text)
        # 3. Update the element content to be plain text
        element.string = plain_text
    elif is_text(element):
        if is_non_printing(element):
            # The simplified HTML may have come from Readability.js so might
            # have non-printing text (e.g. Comment or CData). In this case, we
            # keep the structure, but ensure that the string is empty.
            element = type(element)("")
        else:
            plain_text = element.string
            plain_text = normalize_text(plain_text)
            element = type(element)(plain_text)
    else:
        # If not a leaf node or leaf type, call recursively on the child nodes, replacing them
        element.contents = plain_elements(element.contents, content_digests, node_indexes)
    return element


def add_node_indexes(element, node_index="0"):
    # Can't add attributes to string types
    if is_text(element):
        return element
    # Add an index to the current element
    element["data-node-index"] = node_index
    # Add an index to the child elements
    for local_idx, child in enumerate([c for c in element.contents if not is_text(c)], start=1):
        # Can't add attributes to leaf string types
        child_index = "{stem}.{local}".format(stem=node_index, local=local_idx)
        add_node_indexes(child, node_index=child_index)
    return element


def normalize_text(text):
    """Normalize unicode and whitespace."""
    # Normalize unicode first to try and standardize whitespace characters as much as possible before normalizing them
    text = strip_control_characters(text)
    text = normalize_unicode(text)
    text = normalize_whitespace(text)
    return text


def strip_control_characters(text):
    """Strip out unicode control characters which might break the parsing."""
    # Unicode control characters
    #   [Cc]: Other, Control [includes new lines]
    #   [Cf]: Other, Format
    #   [Cn]: Other, Not Assigned
    #   [Co]: Other, Private Use
    #   [Cs]: Other, Surrogate
    control_chars = {"Cc", "Cf", "Cn", "Co", "Cs"}
    retained_chars = ["\t", "\n", "\r", "\f"]

    # Remove the non-printing control characters
    return "".join(
        [
            "" if (unicodedata.category(char) in control_chars) and (char not in retained_chars) else char
            for char in text
        ]
    )


def normalize_unicode(text):
    """Normalize unicode such that things that are visually equivalent map to the same unicode string where possible."""
    normal_form: Literal["NFC", "NFD", "NFKC", "NFKD"] = "NFKC"
    text = unicodedata.normalize(normal_form, text)
    return text


def normalize_whitespace(text):
    """Replace runs of whitespace characters with a single space as this is what happens when HTML text is displayed."""
    text = regex.sub(r"\s+", " ", text)
    # Remove leading and trailing whitespace
    text = text.strip()
    return text
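
# For example, normalization strips format controls, canonicalizes unicode,
# and collapses whitespace runs:
#
#     normalize_text("Caf\u00e9\u200b   menu\n\n")  # U+200B is a [Cf] format character
#     # -> "Café menu"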
def is_leaf(element):
    return element.name in {"p", "li"}


def is_text(element):
    return isinstance(element, NavigableString)


def is_non_printing(element):
    return any(isinstance(element, _e) for _e in [Comment, CData])


def add_content_digest(element):
    if not is_text(element):
        element["data-content-digest"] = content_digest(element)
    return element


def content_digest(element):
    digest: Any
    if is_text(element):
        # Hash the text string itself
        trimmed_string = element.string.strip()
        if trimmed_string == "":
            digest = ""
        else:
            digest = hashlib.sha256(trimmed_string.encode("utf-8")).hexdigest()
    else:
        contents = element.contents
        num_contents = len(contents)
        if num_contents == 0:
            # No hash when no child elements exist
            digest = ""
        elif num_contents == 1:
            # If there is a single child, use the digest of that child
            digest = content_digest(contents[0])
        else:
            # Build the content digest from the "non-empty" digests of the child nodes
            digest = hashlib.sha256()
            child_digests = list(filter(lambda x: x != "", [content_digest(content) for content in contents]))
            for child in child_digests:
                digest.update(child.encode("utf-8"))
            digest = digest.hexdigest()
    return digest
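
# For example, hashing the text content of a parsed paragraph:
#
#     soup = BeautifulSoup("<p>hello</p>", "html.parser")
#     content_digest(soup.p)  # -> sha256 hex digest of "hello"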
def get_image_upload_file_ids(content):
    # match markdown image links that point at file-preview/image-preview URLs
    pattern = r"!\[image\]\((https?://.*?(file-preview|image-preview))\)"
    matches = re.findall(pattern, content)
    image_upload_file_ids = []
    for match in matches:
        if match[1] == "file-preview":
            content_pattern = r"files/([^/]+)/file-preview"
        else:
            content_pattern = r"files/([^/]+)/image-preview"
        content_match = re.search(content_pattern, match[0])
        if content_match:
            image_upload_file_id = content_match.group(1)
            image_upload_file_ids.append(image_upload_file_id)
    return image_upload_file_ids
| {"not_installed": [], "plugin_install_failed": []} |