| @@ -6,31 +6,30 @@ from typing import TYPE_CHECKING, Any, Optional | |||
| if TYPE_CHECKING: | |||
| from models.model import File | |||
| from core.tools.__base.tool_runtime import ToolRuntime | |||
| from core.datasource.__base.datasource_runtime import DatasourceRuntime | |||
| from core.datasource.entities.datasource_entities import DatasourceEntity, DatasourceProviderType | |||
| from core.tools.entities.tool_entities import ( | |||
| ToolEntity, | |||
| ToolInvokeMessage, | |||
| ToolParameter, | |||
| ToolProviderType, | |||
| ) | |||
| class Tool(ABC): | |||
| class Datasource(ABC): | |||
| """ | |||
| The base class of a tool | |||
| The base class of a datasource | |||
| """ | |||
| entity: ToolEntity | |||
| runtime: ToolRuntime | |||
| entity: DatasourceEntity | |||
| runtime: DatasourceRuntime | |||
| def __init__(self, entity: ToolEntity, runtime: ToolRuntime) -> None: | |||
| def __init__(self, entity: DatasourceEntity, runtime: DatasourceRuntime) -> None: | |||
| self.entity = entity | |||
| self.runtime = runtime | |||
| def fork_tool_runtime(self, runtime: ToolRuntime) -> "Tool": | |||
| def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "Datasource": | |||
| """ | |||
| fork a new tool with metadata | |||
| :return: the new tool | |||
| fork a new datasource with metadata | |||
| :return: the new datasource | |||
| """ | |||
| return self.__class__( | |||
| entity=self.entity.model_copy(), | |||
| @@ -38,9 +37,9 @@ class Tool(ABC): | |||
| ) | |||
| @abstractmethod | |||
| def tool_provider_type(self) -> ToolProviderType: | |||
| def datasource_provider_type(self) -> DatasourceProviderType: | |||
| """ | |||
| get the tool provider type | |||
| get the datasource provider type | |||
| :return: the datasource provider type | |||
| """ | |||
| @@ -4,12 +4,13 @@ from openai import BaseModel | |||
| from pydantic import Field | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from core.datasource.entities.datasource_entities import DatasourceInvokeFrom | |||
| from core.tools.entities.tool_entities import ToolInvokeFrom | |||
| class ToolRuntime(BaseModel): | |||
| class DatasourceRuntime(BaseModel): | |||
| """ | |||
| Meta data of a tool call processing | |||
| Metadata of a datasource invocation | |||
| """ | |||
| tenant_id: str | |||
| @@ -20,17 +21,17 @@ class ToolRuntime(BaseModel): | |||
| runtime_parameters: dict[str, Any] = Field(default_factory=dict) | |||
| class FakeToolRuntime(ToolRuntime): | |||
| class FakeDatasourceRuntime(DatasourceRuntime): | |||
| """ | |||
| Fake tool runtime for testing | |||
| Fake datasource runtime for testing | |||
| """ | |||
| def __init__(self): | |||
| super().__init__( | |||
| tenant_id="fake_tenant_id", | |||
| tool_id="fake_tool_id", | |||
| datasource_id="fake_datasource_id", | |||
| invoke_from=InvokeFrom.DEBUGGER, | |||
| tool_invoke_from=ToolInvokeFrom.AGENT, | |||
| datasource_invoke_from=DatasourceInvokeFrom.RAG_PIPELINE, | |||
| credentials={}, | |||
| runtime_parameters={}, | |||
| ) | |||
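A quick sketch of how the fake runtime could be exercised in a unit test; the import path is an assumption, while the asserted values come straight from the constructor shown above.

```python
# Sketch of a unit test using the fake runtime; the import path is assumed.
from core.datasource.__base.datasource_runtime import DatasourceRuntime, FakeDatasourceRuntime
from core.datasource.entities.datasource_entities import DatasourceInvokeFrom


def test_fake_datasource_runtime_defaults():
    runtime = FakeDatasourceRuntime()
    assert isinstance(runtime, DatasourceRuntime)
    assert runtime.tenant_id == "fake_tenant_id"
    assert runtime.datasource_invoke_from == DatasourceInvokeFrom.RAG_PIPELINE
    assert runtime.credentials == {}
```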
| @@ -36,9 +36,9 @@ from models.enums import CreatedByRole | |||
| from models.model import Message, MessageFile | |||
| class ToolEngine: | |||
| class DatasourceEngine: | |||
| """ | |||
| Tool runtime engine take care of the tool executions. | |||
| Datasource runtime engine that takes care of datasource executions. | |||
| """ | |||
| @staticmethod | |||
| @@ -1,7 +1,9 @@ | |||
| from collections.abc import Generator | |||
| from typing import Any, Optional | |||
| from core.datasource.__base.datasource_runtime import DatasourceRuntime | |||
| from core.datasource.entities.datasource_entities import DatasourceEntity, DatasourceParameter, DatasourceProviderType | |||
| from core.plugin.manager.datasource import PluginDatasourceManager | |||
| from core.plugin.manager.tool import PluginToolManager | |||
| from core.plugin.utils.converter import convert_parameters_to_plugin_format | |||
| from core.tools.__base.tool import Tool | |||
| @@ -9,7 +11,7 @@ from core.tools.__base.tool_runtime import ToolRuntime | |||
| from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolParameter, ToolProviderType | |||
| class DatasourceTool(Tool): | |||
| class DatasourcePlugin(Datasource): | |||
| tenant_id: str | |||
| icon: str | |||
| plugin_unique_identifier: str | |||
| @@ -31,53 +33,45 @@ class DatasourceTool(Tool): | |||
| self, | |||
| user_id: str, | |||
| datasource_parameters: dict[str, Any], | |||
| conversation_id: Optional[str] = None, | |||
| rag_pipeline_id: Optional[str] = None, | |||
| message_id: Optional[str] = None, | |||
| ) -> Generator[ToolInvokeMessage, None, None]: | |||
| manager = PluginToolManager() | |||
| manager = PluginDatasourceManager() | |||
| datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters) | |||
| yield from manager.invoke_first_step( | |||
| tenant_id=self.tenant_id, | |||
| user_id=user_id, | |||
| tool_provider=self.entity.identity.provider, | |||
| tool_name=self.entity.identity.name, | |||
| datasource_provider=self.entity.identity.provider, | |||
| datasource_name=self.entity.identity.name, | |||
| credentials=self.runtime.credentials, | |||
| tool_parameters=tool_parameters, | |||
| conversation_id=conversation_id, | |||
| app_id=app_id, | |||
| message_id=message_id, | |||
| datasource_parameters=datasource_parameters, | |||
| rag_pipeline_id=rag_pipeline_id, | |||
| ) | |||
| def _invoke_second_step( | |||
| self, | |||
| user_id: str, | |||
| datasource_parameters: dict[str, Any], | |||
| conversation_id: Optional[str] = None, | |||
| rag_pipeline_id: Optional[str] = None, | |||
| message_id: Optional[str] = None, | |||
| ) -> Generator[ToolInvokeMessage, None, None]: | |||
| manager = PluginToolManager() | |||
| manager = PluginDatasourceManager() | |||
| datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters) | |||
| yield from manager.invoke( | |||
| tenant_id=self.tenant_id, | |||
| user_id=user_id, | |||
| tool_provider=self.entity.identity.provider, | |||
| tool_name=self.entity.identity.name, | |||
| datasource_provider=self.entity.identity.provider, | |||
| datasource_name=self.entity.identity.name, | |||
| credentials=self.runtime.credentials, | |||
| tool_parameters=tool_parameters, | |||
| conversation_id=conversation_id, | |||
| app_id=app_id, | |||
| message_id=message_id, | |||
| datasource_parameters=datasource_parameters, | |||
| rag_pipeline_id=rag_pipeline_id, | |||
| ) | |||
| def fork_tool_runtime(self, runtime: ToolRuntime) -> "DatasourceTool": | |||
| return DatasourceTool( | |||
| def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin": | |||
| return DatasourcePlugin( | |||
| entity=self.entity, | |||
| runtime=runtime, | |||
| tenant_id=self.tenant_id, | |||
| @@ -87,9 +81,7 @@ class DatasourceTool(Tool): | |||
| def get_runtime_parameters( | |||
| self, | |||
| conversation_id: Optional[str] = None, | |||
| app_id: Optional[str] = None, | |||
| message_id: Optional[str] = None, | |||
| rag_pipeline_id: Optional[str] = None, | |||
| ) -> list[DatasourceParameter]: | |||
| """ | |||
| get the runtime parameters | |||
| @@ -100,16 +92,14 @@ class DatasourceTool(Tool): | |||
| if self.runtime_parameters is not None: | |||
| return self.runtime_parameters | |||
| manager = PluginToolManager() | |||
| manager = PluginDatasourceManager() | |||
| self.runtime_parameters = manager.get_runtime_parameters( | |||
| tenant_id=self.tenant_id, | |||
| user_id="", | |||
| provider=self.entity.identity.provider, | |||
| tool=self.entity.identity.name, | |||
| datasource=self.entity.identity.name, | |||
| credentials=self.runtime.credentials, | |||
| conversation_id=conversation_id, | |||
| app_id=app_id, | |||
| message_id=message_id, | |||
| rag_pipeline_id=rag_pipeline_id, | |||
| ) | |||
| return self.runtime_parameters | |||
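A hedged sketch of how a caller might drive a `DatasourcePlugin`: fork a runtime bound to fresh credentials, list the runtime parameters, then stream the first-step invocation. The first-step method name (`_invoke_first_step`) mirrors `_invoke_second_step` and is an assumption, as is the way the `plugin` instance is obtained.

```python
# Sketch only: assumes an already-constructed `plugin: DatasourcePlugin` and
# that the first-step method is named `_invoke_first_step` (only the second
# step's name appears in this hunk). In real code the engine calls these
# underscore methods, not application code.
from core.app.entities.app_invoke_entities import InvokeFrom
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import DatasourceInvokeFrom


def run_datasource(plugin, user_id: str, pipeline_id: str, credentials: dict) -> list:
    # fork a copy bound to fresh credentials instead of mutating the original
    forked = plugin.fork_datasource_runtime(
        DatasourceRuntime(
            tenant_id=plugin.tenant_id,
            datasource_id="example_datasource_id",
            invoke_from=InvokeFrom.DEBUGGER,
            datasource_invoke_from=DatasourceInvokeFrom.RAG_PIPELINE,
            credentials=credentials,
        )
    )

    # discover the parameters the datasource expects for this pipeline
    parameters = forked.get_runtime_parameters(rag_pipeline_id=pipeline_id)
    print([parameter.name for parameter in parameters])

    # stream the first-step invocation messages
    messages = []
    for message in forked._invoke_first_step(
        user_id=user_id,
        datasource_parameters={"query": "example"},
        rag_pipeline_id=pipeline_id,
    ):
        messages.append(message)
    return messages
```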
| @@ -1,199 +0,0 @@ | |||
| import threading | |||
| from typing import Any | |||
| from flask import Flask, current_app | |||
| from pydantic import BaseModel, Field | |||
| from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler | |||
| from core.model_manager import ModelManager | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from core.rag.datasource.retrieval_service import RetrievalService | |||
| from core.rag.models.document import Document as RagDocument | |||
| from core.rag.rerank.rerank_model import RerankModelRunner | |||
| from core.rag.retrieval.retrieval_methods import RetrievalMethod | |||
| from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool | |||
| from extensions.ext_database import db | |||
| from models.dataset import Dataset, Document, DocumentSegment | |||
| default_retrieval_model: dict[str, Any] = { | |||
| "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, | |||
| "reranking_enable": False, | |||
| "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, | |||
| "top_k": 2, | |||
| "score_threshold_enabled": False, | |||
| } | |||
| class DatasetMultiRetrieverToolInput(BaseModel): | |||
| query: str = Field(..., description="dataset multi retriever and rerank") | |||
| class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): | |||
| """Tool for querying multi dataset.""" | |||
| name: str = "dataset_" | |||
| args_schema: type[BaseModel] = DatasetMultiRetrieverToolInput | |||
| description: str = "dataset multi retriever and rerank. " | |||
| dataset_ids: list[str] | |||
| reranking_provider_name: str | |||
| reranking_model_name: str | |||
| @classmethod | |||
| def from_dataset(cls, dataset_ids: list[str], tenant_id: str, **kwargs): | |||
| return cls( | |||
| name=f"dataset_{tenant_id.replace('-', '_')}", tenant_id=tenant_id, dataset_ids=dataset_ids, **kwargs | |||
| ) | |||
| def _run(self, query: str) -> str: | |||
| threads = [] | |||
| all_documents: list[RagDocument] = [] | |||
| for dataset_id in self.dataset_ids: | |||
| retrieval_thread = threading.Thread( | |||
| target=self._retriever, | |||
| kwargs={ | |||
| "flask_app": current_app._get_current_object(), # type: ignore | |||
| "dataset_id": dataset_id, | |||
| "query": query, | |||
| "all_documents": all_documents, | |||
| "hit_callbacks": self.hit_callbacks, | |||
| }, | |||
| ) | |||
| threads.append(retrieval_thread) | |||
| retrieval_thread.start() | |||
| for thread in threads: | |||
| thread.join() | |||
| # do rerank for searched documents | |||
| model_manager = ModelManager() | |||
| rerank_model_instance = model_manager.get_model_instance( | |||
| tenant_id=self.tenant_id, | |||
| provider=self.reranking_provider_name, | |||
| model_type=ModelType.RERANK, | |||
| model=self.reranking_model_name, | |||
| ) | |||
| rerank_runner = RerankModelRunner(rerank_model_instance) | |||
| all_documents = rerank_runner.run(query, all_documents, self.score_threshold, self.top_k) | |||
| for hit_callback in self.hit_callbacks: | |||
| hit_callback.on_tool_end(all_documents) | |||
| document_score_list = {} | |||
| for item in all_documents: | |||
| if item.metadata and item.metadata.get("score"): | |||
| document_score_list[item.metadata["doc_id"]] = item.metadata["score"] | |||
| document_context_list = [] | |||
| index_node_ids = [document.metadata["doc_id"] for document in all_documents if document.metadata] | |||
| segments = DocumentSegment.query.filter( | |||
| DocumentSegment.dataset_id.in_(self.dataset_ids), | |||
| DocumentSegment.completed_at.isnot(None), | |||
| DocumentSegment.status == "completed", | |||
| DocumentSegment.enabled == True, | |||
| DocumentSegment.index_node_id.in_(index_node_ids), | |||
| ).all() | |||
| if segments: | |||
| index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} | |||
| sorted_segments = sorted( | |||
| segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf")) | |||
| ) | |||
| for segment in sorted_segments: | |||
| if segment.answer: | |||
| document_context_list.append(f"question:{segment.get_sign_content()} answer:{segment.answer}") | |||
| else: | |||
| document_context_list.append(segment.get_sign_content()) | |||
| if self.return_resource: | |||
| context_list = [] | |||
| resource_number = 1 | |||
| for segment in sorted_segments: | |||
| dataset = Dataset.query.filter_by(id=segment.dataset_id).first() | |||
| document = Document.query.filter( | |||
| Document.id == segment.document_id, | |||
| Document.enabled == True, | |||
| Document.archived == False, | |||
| ).first() | |||
| if dataset and document: | |||
| source = { | |||
| "position": resource_number, | |||
| "dataset_id": dataset.id, | |||
| "dataset_name": dataset.name, | |||
| "document_id": document.id, | |||
| "document_name": document.name, | |||
| "data_source_type": document.data_source_type, | |||
| "segment_id": segment.id, | |||
| "retriever_from": self.retriever_from, | |||
| "score": document_score_list.get(segment.index_node_id, None), | |||
| "doc_metadata": document.doc_metadata, | |||
| } | |||
| if self.retriever_from == "dev": | |||
| source["hit_count"] = segment.hit_count | |||
| source["word_count"] = segment.word_count | |||
| source["segment_position"] = segment.position | |||
| source["index_node_hash"] = segment.index_node_hash | |||
| if segment.answer: | |||
| source["content"] = f"question:{segment.content} \nanswer:{segment.answer}" | |||
| else: | |||
| source["content"] = segment.content | |||
| context_list.append(source) | |||
| resource_number += 1 | |||
| for hit_callback in self.hit_callbacks: | |||
| hit_callback.return_retriever_resource_info(context_list) | |||
| return str("\n".join(document_context_list)) | |||
| return "" | |||
| raise RuntimeError("not segments found") | |||
| def _retriever( | |||
| self, | |||
| flask_app: Flask, | |||
| dataset_id: str, | |||
| query: str, | |||
| all_documents: list, | |||
| hit_callbacks: list[DatasetIndexToolCallbackHandler], | |||
| ): | |||
| with flask_app.app_context(): | |||
| dataset = ( | |||
| db.session.query(Dataset).filter(Dataset.tenant_id == self.tenant_id, Dataset.id == dataset_id).first() | |||
| ) | |||
| if not dataset: | |||
| return [] | |||
| for hit_callback in hit_callbacks: | |||
| hit_callback.on_query(query, dataset.id) | |||
| # get retrieval model , if the model is not setting , using default | |||
| retrieval_model = dataset.retrieval_model or default_retrieval_model | |||
| if dataset.indexing_technique == "economy": | |||
| # use keyword table query | |||
| documents = RetrievalService.retrieve( | |||
| retrieval_method="keyword_search", | |||
| dataset_id=dataset.id, | |||
| query=query, | |||
| top_k=retrieval_model.get("top_k") or 2, | |||
| ) | |||
| if documents: | |||
| all_documents.extend(documents) | |||
| else: | |||
| if self.top_k > 0: | |||
| # retrieval source | |||
| documents = RetrievalService.retrieve( | |||
| retrieval_method=retrieval_model["search_method"], | |||
| dataset_id=dataset.id, | |||
| query=query, | |||
| top_k=retrieval_model.get("top_k") or 2, | |||
| score_threshold=retrieval_model.get("score_threshold", 0.0) | |||
| if retrieval_model["score_threshold_enabled"] | |||
| else 0.0, | |||
| reranking_model=retrieval_model.get("reranking_model", None) | |||
| if retrieval_model["reranking_enable"] | |||
| else None, | |||
| reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model", | |||
| weights=retrieval_model.get("weights", None), | |||
| ) | |||
| all_documents.extend(documents) | |||
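The removed multi-dataset tool fans retrieval out across threads, one per dataset, each appending into a shared list before a single rerank pass. Below is a self-contained sketch of that fan-out/join pattern with generic names; `retrieve()` is a stand-in, not a Dify API.

```python
# Generic sketch of the fan-out/join pattern used by _run above.
import threading


def retrieve(dataset_id: str, query: str) -> list[str]:
    # stand-in for RetrievalService.retrieve(...)
    return [f"{dataset_id}:{query}"]


def fan_out(dataset_ids: list[str], query: str) -> list[str]:
    all_documents: list[str] = []
    lock = threading.Lock()
    threads = []

    def worker(dataset_id: str) -> None:
        documents = retrieve(dataset_id, query)
        with lock:  # list.extend is effectively atomic in CPython, but be explicit
            all_documents.extend(documents)

    for dataset_id in dataset_ids:
        thread = threading.Thread(target=worker, args=(dataset_id,))
        threads.append(thread)
        thread.start()

    for thread in threads:
        thread.join()
    return all_documents


print(fan_out(["ds_a", "ds_b"], "hello"))
```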
| @@ -1,33 +0,0 @@ | |||
| from abc import abstractmethod | |||
| from typing import Any, Optional | |||
| from msal_extensions.persistence import ABC # type: ignore | |||
| from pydantic import BaseModel, ConfigDict | |||
| from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler | |||
| class DatasetRetrieverBaseTool(BaseModel, ABC): | |||
| """Tool for querying a Dataset.""" | |||
| name: str = "dataset" | |||
| description: str = "use this to retrieve a dataset. " | |||
| tenant_id: str | |||
| top_k: int = 2 | |||
| score_threshold: Optional[float] = None | |||
| hit_callbacks: list[DatasetIndexToolCallbackHandler] = [] | |||
| return_resource: bool | |||
| retriever_from: str | |||
| model_config = ConfigDict(arbitrary_types_allowed=True) | |||
| @abstractmethod | |||
| def _run( | |||
| self, | |||
| *args: Any, | |||
| **kwargs: Any, | |||
| ) -> Any: | |||
| """Use the tool. | |||
| Add run_manager: Optional[CallbackManagerForToolRun] = None | |||
| to child implementations to enable tracing, | |||
| """ | |||
| @@ -1,202 +0,0 @@ | |||
| from typing import Any | |||
| from pydantic import BaseModel, Field | |||
| from core.rag.datasource.retrieval_service import RetrievalService | |||
| from core.rag.entities.context_entities import DocumentContext | |||
| from core.rag.models.document import Document as RetrievalDocument | |||
| from core.rag.retrieval.retrieval_methods import RetrievalMethod | |||
| from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool | |||
| from extensions.ext_database import db | |||
| from models.dataset import Dataset | |||
| from models.dataset import Document as DatasetDocument | |||
| from services.external_knowledge_service import ExternalDatasetService | |||
| default_retrieval_model = { | |||
| "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, | |||
| "reranking_enable": False, | |||
| "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, | |||
| "reranking_mode": "reranking_model", | |||
| "top_k": 2, | |||
| "score_threshold_enabled": False, | |||
| } | |||
| class DatasetRetrieverToolInput(BaseModel): | |||
| query: str = Field(..., description="Query for the dataset to be used to retrieve the dataset.") | |||
| class DatasetRetrieverTool(DatasetRetrieverBaseTool): | |||
| """Tool for querying a Dataset.""" | |||
| name: str = "dataset" | |||
| args_schema: type[BaseModel] = DatasetRetrieverToolInput | |||
| description: str = "use this to retrieve a dataset. " | |||
| dataset_id: str | |||
| @classmethod | |||
| def from_dataset(cls, dataset: Dataset, **kwargs): | |||
| description = dataset.description | |||
| if not description: | |||
| description = "useful for when you want to answer queries about the " + dataset.name | |||
| description = description.replace("\n", "").replace("\r", "") | |||
| return cls( | |||
| name=f"dataset_{dataset.id.replace('-', '_')}", | |||
| tenant_id=dataset.tenant_id, | |||
| dataset_id=dataset.id, | |||
| description=description, | |||
| **kwargs, | |||
| ) | |||
| def _run(self, query: str) -> str: | |||
| dataset = ( | |||
| db.session.query(Dataset).filter(Dataset.tenant_id == self.tenant_id, Dataset.id == self.dataset_id).first() | |||
| ) | |||
| if not dataset: | |||
| return "" | |||
| for hit_callback in self.hit_callbacks: | |||
| hit_callback.on_query(query, dataset.id) | |||
| if dataset.provider == "external": | |||
| results = [] | |||
| external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval( | |||
| tenant_id=dataset.tenant_id, | |||
| dataset_id=dataset.id, | |||
| query=query, | |||
| external_retrieval_parameters=dataset.retrieval_model, | |||
| ) | |||
| for external_document in external_documents: | |||
| document = RetrievalDocument( | |||
| page_content=external_document.get("content"), | |||
| metadata=external_document.get("metadata"), | |||
| provider="external", | |||
| ) | |||
| if document.metadata is not None: | |||
| document.metadata["score"] = external_document.get("score") | |||
| document.metadata["title"] = external_document.get("title") | |||
| document.metadata["dataset_id"] = dataset.id | |||
| document.metadata["dataset_name"] = dataset.name | |||
| results.append(document) | |||
| # deal with external documents | |||
| context_list = [] | |||
| for position, item in enumerate(results, start=1): | |||
| if item.metadata is not None: | |||
| source = { | |||
| "position": position, | |||
| "dataset_id": item.metadata.get("dataset_id"), | |||
| "dataset_name": item.metadata.get("dataset_name"), | |||
| "document_name": item.metadata.get("title"), | |||
| "data_source_type": "external", | |||
| "retriever_from": self.retriever_from, | |||
| "score": item.metadata.get("score"), | |||
| "title": item.metadata.get("title"), | |||
| "content": item.page_content, | |||
| } | |||
| context_list.append(source) | |||
| for hit_callback in self.hit_callbacks: | |||
| hit_callback.return_retriever_resource_info(context_list) | |||
| return str("\n".join([item.page_content for item in results])) | |||
| else: | |||
| # get retrieval model , if the model is not setting , using default | |||
| retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model | |||
| if dataset.indexing_technique == "economy": | |||
| # use keyword table query | |||
| documents = RetrievalService.retrieve( | |||
| retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=self.top_k | |||
| ) | |||
| return str("\n".join([document.page_content for document in documents])) | |||
| else: | |||
| if self.top_k > 0: | |||
| # retrieval source | |||
| documents = RetrievalService.retrieve( | |||
| retrieval_method=retrieval_model.get("search_method", "semantic_search"), | |||
| dataset_id=dataset.id, | |||
| query=query, | |||
| top_k=self.top_k, | |||
| score_threshold=retrieval_model.get("score_threshold", 0.0) | |||
| if retrieval_model["score_threshold_enabled"] | |||
| else 0.0, | |||
| reranking_model=retrieval_model.get("reranking_model") | |||
| if retrieval_model["reranking_enable"] | |||
| else None, | |||
| reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model", | |||
| weights=retrieval_model.get("weights"), | |||
| ) | |||
| else: | |||
| documents = [] | |||
| for hit_callback in self.hit_callbacks: | |||
| hit_callback.on_tool_end(documents) | |||
| document_score_list = {} | |||
| if dataset.indexing_technique != "economy": | |||
| for item in documents: | |||
| if item.metadata is not None and item.metadata.get("score"): | |||
| document_score_list[item.metadata["doc_id"]] = item.metadata["score"] | |||
| document_context_list = [] | |||
| records = RetrievalService.format_retrieval_documents(documents) | |||
| if records: | |||
| for record in records: | |||
| segment = record.segment | |||
| if segment.answer: | |||
| document_context_list.append( | |||
| DocumentContext( | |||
| content=f"question:{segment.get_sign_content()} answer:{segment.answer}", | |||
| score=record.score, | |||
| ) | |||
| ) | |||
| else: | |||
| document_context_list.append( | |||
| DocumentContext( | |||
| content=segment.get_sign_content(), | |||
| score=record.score, | |||
| ) | |||
| ) | |||
| retrieval_resource_list = [] | |||
| if self.return_resource: | |||
| for record in records: | |||
| segment = record.segment | |||
| dataset = Dataset.query.filter_by(id=segment.dataset_id).first() | |||
| document = DatasetDocument.query.filter( | |||
| DatasetDocument.id == segment.document_id, | |||
| DatasetDocument.enabled == True, | |||
| DatasetDocument.archived == False, | |||
| ).first() | |||
| if dataset and document: | |||
| source = { | |||
| "dataset_id": dataset.id, | |||
| "dataset_name": dataset.name, | |||
| "document_id": document.id, # type: ignore | |||
| "document_name": document.name, # type: ignore | |||
| "data_source_type": document.data_source_type, # type: ignore | |||
| "segment_id": segment.id, | |||
| "retriever_from": self.retriever_from, | |||
| "score": record.score or 0.0, | |||
| "doc_metadata": document.doc_metadata, # type: ignore | |||
| } | |||
| if self.retriever_from == "dev": | |||
| source["hit_count"] = segment.hit_count | |||
| source["word_count"] = segment.word_count | |||
| source["segment_position"] = segment.position | |||
| source["index_node_hash"] = segment.index_node_hash | |||
| if segment.answer: | |||
| source["content"] = f"question:{segment.content} \nanswer:{segment.answer}" | |||
| else: | |||
| source["content"] = segment.content | |||
| retrieval_resource_list.append(source) | |||
| if self.return_resource and retrieval_resource_list: | |||
| retrieval_resource_list = sorted( | |||
| retrieval_resource_list, | |||
| key=lambda x: x.get("score") or 0.0, | |||
| reverse=True, | |||
| ) | |||
| for position, item in enumerate(retrieval_resource_list, start=1): # type: ignore | |||
| item["position"] = position # type: ignore | |||
| for hit_callback in self.hit_callbacks: | |||
| hit_callback.return_retriever_resource_info(retrieval_resource_list) | |||
| if document_context_list: | |||
| document_context_list = sorted(document_context_list, key=lambda x: x.score or 0.0, reverse=True) | |||
| return str("\n".join([document_context.content for document_context in document_context_list])) | |||
| return "" | |||
| @@ -1,134 +0,0 @@ | |||
| from collections.abc import Generator | |||
| from typing import Any, Optional | |||
| from core.app.app_config.entities import DatasetRetrieveConfigEntity | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler | |||
| from core.rag.retrieval.dataset_retrieval import DatasetRetrieval | |||
| from core.tools.__base.tool import Tool | |||
| from core.tools.__base.tool_runtime import ToolRuntime | |||
| from core.tools.entities.common_entities import I18nObject | |||
| from core.tools.entities.tool_entities import ( | |||
| ToolDescription, | |||
| ToolEntity, | |||
| ToolIdentity, | |||
| ToolInvokeMessage, | |||
| ToolParameter, | |||
| ToolProviderType, | |||
| ) | |||
| from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool | |||
| class DatasetRetrieverTool(Tool): | |||
| retrieval_tool: DatasetRetrieverBaseTool | |||
| def __init__(self, entity: ToolEntity, runtime: ToolRuntime, retrieval_tool: DatasetRetrieverBaseTool) -> None: | |||
| super().__init__(entity, runtime) | |||
| self.retrieval_tool = retrieval_tool | |||
| @staticmethod | |||
| def get_dataset_tools( | |||
| tenant_id: str, | |||
| dataset_ids: list[str], | |||
| retrieve_config: DatasetRetrieveConfigEntity | None, | |||
| return_resource: bool, | |||
| invoke_from: InvokeFrom, | |||
| hit_callback: DatasetIndexToolCallbackHandler, | |||
| ) -> list["DatasetRetrieverTool"]: | |||
| """ | |||
| get dataset tool | |||
| """ | |||
| # check if retrieve_config is valid | |||
| if dataset_ids is None or len(dataset_ids) == 0: | |||
| return [] | |||
| if retrieve_config is None: | |||
| return [] | |||
| feature = DatasetRetrieval() | |||
| # save original retrieve strategy, and set retrieve strategy to SINGLE | |||
| # Agent only support SINGLE mode | |||
| original_retriever_mode = retrieve_config.retrieve_strategy | |||
| retrieve_config.retrieve_strategy = DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE | |||
| retrieval_tools = feature.to_dataset_retriever_tool( | |||
| tenant_id=tenant_id, | |||
| dataset_ids=dataset_ids, | |||
| retrieve_config=retrieve_config, | |||
| return_resource=return_resource, | |||
| invoke_from=invoke_from, | |||
| hit_callback=hit_callback, | |||
| ) | |||
| if retrieval_tools is None or len(retrieval_tools) == 0: | |||
| return [] | |||
| # restore retrieve strategy | |||
| retrieve_config.retrieve_strategy = original_retriever_mode | |||
| # convert retrieval tools to Tools | |||
| tools = [] | |||
| for retrieval_tool in retrieval_tools: | |||
| tool = DatasetRetrieverTool( | |||
| retrieval_tool=retrieval_tool, | |||
| entity=ToolEntity( | |||
| identity=ToolIdentity( | |||
| provider="", author="", name=retrieval_tool.name, label=I18nObject(en_US="", zh_Hans="") | |||
| ), | |||
| parameters=[], | |||
| description=ToolDescription(human=I18nObject(en_US="", zh_Hans=""), llm=retrieval_tool.description), | |||
| ), | |||
| runtime=ToolRuntime(tenant_id=tenant_id), | |||
| ) | |||
| tools.append(tool) | |||
| return tools | |||
| def get_runtime_parameters( | |||
| self, | |||
| conversation_id: Optional[str] = None, | |||
| app_id: Optional[str] = None, | |||
| message_id: Optional[str] = None, | |||
| ) -> list[ToolParameter]: | |||
| return [ | |||
| ToolParameter( | |||
| name="query", | |||
| label=I18nObject(en_US="", zh_Hans=""), | |||
| human_description=I18nObject(en_US="", zh_Hans=""), | |||
| type=ToolParameter.ToolParameterType.STRING, | |||
| form=ToolParameter.ToolParameterForm.LLM, | |||
| llm_description="Query for the dataset to be used to retrieve the dataset.", | |||
| required=True, | |||
| default="", | |||
| placeholder=I18nObject(en_US="", zh_Hans=""), | |||
| ), | |||
| ] | |||
| def tool_provider_type(self) -> ToolProviderType: | |||
| return ToolProviderType.DATASET_RETRIEVAL | |||
| def _invoke( | |||
| self, | |||
| user_id: str, | |||
| tool_parameters: dict[str, Any], | |||
| conversation_id: Optional[str] = None, | |||
| app_id: Optional[str] = None, | |||
| message_id: Optional[str] = None, | |||
| ) -> Generator[ToolInvokeMessage, None, None]: | |||
| """ | |||
| invoke dataset retriever tool | |||
| """ | |||
| query = tool_parameters.get("query") | |||
| if not query: | |||
| yield self.create_text_message(text="please input query") | |||
| else: | |||
| # invoke dataset retriever tool | |||
| result = self.retrieval_tool._run(query=query) | |||
| yield self.create_text_message(text=result) | |||
| def validate_credentials( | |||
| self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False | |||
| ) -> str | None: | |||
| """ | |||
| validate the credentials for dataset retriever tool | |||
| """ | |||
| pass | |||
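The wrapper above simply delegates to `retrieval_tool._run`; a sketch of invoking it directly is shown below. Parameter names follow the `_invoke` signature in this file, while the surrounding setup (obtaining `tool` from `get_dataset_tools`) is assumed.

```python
# Sketch of invoking the wrapper directly; `tool` is assumed to come from
# DatasetRetrieverTool.get_dataset_tools(...) as shown above.
def collect_text(tool, user_id: str, query: str) -> str:
    chunks = []
    for message in tool._invoke(user_id=user_id, tool_parameters={"query": query}):
        # each yielded ToolInvokeMessage carries the text produced by
        # create_text_message(); the exact attribute layout is not shown here
        chunks.append(message)
    return "\n".join(str(chunk) for chunk in chunks)
```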
| @@ -1,169 +0,0 @@ | |||
| """ | |||
| For some reason, model will be used in tools like WebScraperTool, WikipediaSearchTool etc. | |||
| Therefore, a model manager is needed to list/invoke/validate models. | |||
| """ | |||
| import json | |||
| from typing import Optional, cast | |||
| from core.model_manager import ModelManager | |||
| from core.model_runtime.entities.llm_entities import LLMResult | |||
| from core.model_runtime.entities.message_entities import PromptMessage | |||
| from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType | |||
| from core.model_runtime.errors.invoke import ( | |||
| InvokeAuthorizationError, | |||
| InvokeBadRequestError, | |||
| InvokeConnectionError, | |||
| InvokeRateLimitError, | |||
| InvokeServerUnavailableError, | |||
| ) | |||
| from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel | |||
| from core.model_runtime.utils.encoders import jsonable_encoder | |||
| from extensions.ext_database import db | |||
| from models.tools import ToolModelInvoke | |||
| class InvokeModelError(Exception): | |||
| pass | |||
| class ModelInvocationUtils: | |||
| @staticmethod | |||
| def get_max_llm_context_tokens( | |||
| tenant_id: str, | |||
| ) -> int: | |||
| """ | |||
| get max llm context tokens of the model | |||
| """ | |||
| model_manager = ModelManager() | |||
| model_instance = model_manager.get_default_model_instance( | |||
| tenant_id=tenant_id, | |||
| model_type=ModelType.LLM, | |||
| ) | |||
| if not model_instance: | |||
| raise InvokeModelError("Model not found") | |||
| llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) | |||
| schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials) | |||
| if not schema: | |||
| raise InvokeModelError("No model schema found") | |||
| max_tokens: Optional[int] = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE, None) | |||
| if max_tokens is None: | |||
| return 2048 | |||
| return max_tokens | |||
| @staticmethod | |||
| def calculate_tokens(tenant_id: str, prompt_messages: list[PromptMessage]) -> int: | |||
| """ | |||
| calculate tokens from prompt messages and model parameters | |||
| """ | |||
| # get model instance | |||
| model_manager = ModelManager() | |||
| model_instance = model_manager.get_default_model_instance(tenant_id=tenant_id, model_type=ModelType.LLM) | |||
| if not model_instance: | |||
| raise InvokeModelError("Model not found") | |||
| # get tokens | |||
| tokens = model_instance.get_llm_num_tokens(prompt_messages) | |||
| return tokens | |||
| @staticmethod | |||
| def invoke( | |||
| user_id: str, tenant_id: str, tool_type: str, tool_name: str, prompt_messages: list[PromptMessage] | |||
| ) -> LLMResult: | |||
| """ | |||
| invoke model with parameters in user's own context | |||
| :param user_id: user id | |||
| :param tenant_id: tenant id, the tenant id of the creator of the tool | |||
| :param tool_type: tool type | |||
| :param tool_name: tool name | |||
| :param prompt_messages: prompt messages | |||
| :return: AssistantPromptMessage | |||
| """ | |||
| # get model manager | |||
| model_manager = ModelManager() | |||
| # get model instance | |||
| model_instance = model_manager.get_default_model_instance( | |||
| tenant_id=tenant_id, | |||
| model_type=ModelType.LLM, | |||
| ) | |||
| # get prompt tokens | |||
| prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages) | |||
| model_parameters = { | |||
| "temperature": 0.8, | |||
| "top_p": 0.8, | |||
| } | |||
| # create tool model invoke | |||
| tool_model_invoke = ToolModelInvoke( | |||
| user_id=user_id, | |||
| tenant_id=tenant_id, | |||
| provider=model_instance.provider, | |||
| tool_type=tool_type, | |||
| tool_name=tool_name, | |||
| model_parameters=json.dumps(model_parameters), | |||
| prompt_messages=json.dumps(jsonable_encoder(prompt_messages)), | |||
| model_response="", | |||
| prompt_tokens=prompt_tokens, | |||
| answer_tokens=0, | |||
| answer_unit_price=0, | |||
| answer_price_unit=0, | |||
| provider_response_latency=0, | |||
| total_price=0, | |||
| currency="USD", | |||
| ) | |||
| db.session.add(tool_model_invoke) | |||
| db.session.commit() | |||
| try: | |||
| response: LLMResult = cast( | |||
| LLMResult, | |||
| model_instance.invoke_llm( | |||
| prompt_messages=prompt_messages, | |||
| model_parameters=model_parameters, | |||
| tools=[], | |||
| stop=[], | |||
| stream=False, | |||
| user=user_id, | |||
| callbacks=[], | |||
| ), | |||
| ) | |||
| except InvokeRateLimitError as e: | |||
| raise InvokeModelError(f"Invoke rate limit error: {e}") | |||
| except InvokeBadRequestError as e: | |||
| raise InvokeModelError(f"Invoke bad request error: {e}") | |||
| except InvokeConnectionError as e: | |||
| raise InvokeModelError(f"Invoke connection error: {e}") | |||
| except InvokeAuthorizationError as e: | |||
| raise InvokeModelError("Invoke authorization error") | |||
| except InvokeServerUnavailableError as e: | |||
| raise InvokeModelError(f"Invoke server unavailable error: {e}") | |||
| except Exception as e: | |||
| raise InvokeModelError(f"Invoke error: {e}") | |||
| # update tool model invoke | |||
| tool_model_invoke.model_response = response.message.content | |||
| if response.usage: | |||
| tool_model_invoke.answer_tokens = response.usage.completion_tokens | |||
| tool_model_invoke.answer_unit_price = response.usage.completion_unit_price | |||
| tool_model_invoke.answer_price_unit = response.usage.completion_price_unit | |||
| tool_model_invoke.provider_response_latency = response.usage.latency | |||
| tool_model_invoke.total_price = response.usage.total_price | |||
| tool_model_invoke.currency = response.usage.currency | |||
| db.session.commit() | |||
| return response | |||
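A sketch of how the removed helper was meant to be used: check the context budget, then invoke the tenant's default LLM. It assumes `UserPromptMessage` is the prompt-message class in `core.model_runtime.entities.message_entities` and that a default LLM is configured for the tenant.

```python
# Sketch of the removed helper's intended use; assumes a default LLM is
# configured for the tenant and that UserPromptMessage exists in
# core.model_runtime.entities.message_entities.
from core.model_runtime.entities.message_entities import UserPromptMessage

prompt_messages = [UserPromptMessage(content="Summarize this page in one sentence.")]

budget = ModelInvocationUtils.get_max_llm_context_tokens(tenant_id="tenant-123")
used = ModelInvocationUtils.calculate_tokens(tenant_id="tenant-123", prompt_messages=prompt_messages)

if used < budget:
    result = ModelInvocationUtils.invoke(
        user_id="user-1",
        tenant_id="tenant-123",
        tool_type="web_scraper",
        tool_name="summarize",
        prompt_messages=prompt_messages,
    )
    print(result.message.content)
```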
| @@ -1,17 +0,0 @@ | |||
| import re | |||
| def get_image_upload_file_ids(content): | |||
| pattern = r"!\[image\]\((http?://.*?(file-preview|image-preview))\)" | |||
| matches = re.findall(pattern, content) | |||
| image_upload_file_ids = [] | |||
| for match in matches: | |||
| if match[1] == "file-preview": | |||
| content_pattern = r"files/([^/]+)/file-preview" | |||
| else: | |||
| content_pattern = r"files/([^/]+)/image-preview" | |||
| content_match = re.search(content_pattern, match[0]) | |||
| if content_match: | |||
| image_upload_file_id = content_match.group(1) | |||
| image_upload_file_ids.append(image_upload_file_id) | |||
| return image_upload_file_ids | |||
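A small example of what the extractor matches. Note that the pattern as written (`http?://`) only matches plain-http URLs, so https links would be skipped.

```python
# Example input for get_image_upload_file_ids; the regex above is "http?://",
# so only http:// URLs match as written.
content = (
    "Here is a chart: "
    "![image](http://example.com/files/1f2e3d4c/image-preview) "
    "and an attachment ![image](http://example.com/files/9a8b7c6d/file-preview)"
)
print(get_image_upload_file_ids(content))
# -> ['1f2e3d4c', '9a8b7c6d']
```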
| @@ -1,375 +0,0 @@ | |||
| import hashlib | |||
| import json | |||
| import mimetypes | |||
| import os | |||
| import re | |||
| import site | |||
| import subprocess | |||
| import tempfile | |||
| import unicodedata | |||
| from contextlib import contextmanager | |||
| from pathlib import Path | |||
| from typing import Any, Literal, Optional, cast | |||
| from urllib.parse import unquote | |||
| import chardet | |||
| import cloudscraper # type: ignore | |||
| from bs4 import BeautifulSoup, CData, Comment, NavigableString # type: ignore | |||
| from regex import regex # type: ignore | |||
| from core.helper import ssrf_proxy | |||
| from core.rag.extractor import extract_processor | |||
| from core.rag.extractor.extract_processor import ExtractProcessor | |||
| FULL_TEMPLATE = """ | |||
| TITLE: {title} | |||
| AUTHORS: {authors} | |||
| PUBLISH DATE: {publish_date} | |||
| TOP_IMAGE_URL: {top_image} | |||
| TEXT: | |||
| {text} | |||
| """ | |||
| def page_result(text: str, cursor: int, max_length: int) -> str: | |||
| """Page through `text` and return a substring of `max_length` characters starting from `cursor`.""" | |||
| return text[cursor : cursor + max_length] | |||
| def get_url(url: str, user_agent: Optional[str] = None) -> str: | |||
| """Fetch URL and return the contents as a string.""" | |||
| headers = { | |||
| "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)" | |||
| " Chrome/91.0.4472.124 Safari/537.36" | |||
| } | |||
| if user_agent: | |||
| headers["User-Agent"] = user_agent | |||
| main_content_type = None | |||
| supported_content_types = extract_processor.SUPPORT_URL_CONTENT_TYPES + ["text/html"] | |||
| response = ssrf_proxy.head(url, headers=headers, follow_redirects=True, timeout=(5, 10)) | |||
| if response.status_code == 200: | |||
| # check content-type | |||
| content_type = response.headers.get("Content-Type") | |||
| if content_type: | |||
| main_content_type = response.headers.get("Content-Type").split(";")[0].strip() | |||
| else: | |||
| content_disposition = response.headers.get("Content-Disposition", "") | |||
| filename_match = re.search(r'filename="([^"]+)"', content_disposition) | |||
| if filename_match: | |||
| filename = unquote(filename_match.group(1)) | |||
| extension = re.search(r"\.(\w+)$", filename) | |||
| if extension: | |||
| main_content_type = mimetypes.guess_type(filename)[0] | |||
| if main_content_type not in supported_content_types: | |||
| return "Unsupported content-type [{}] of URL.".format(main_content_type) | |||
| if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES: | |||
| return cast(str, ExtractProcessor.load_from_url(url, return_text=True)) | |||
| response = ssrf_proxy.get(url, headers=headers, follow_redirects=True, timeout=(120, 300)) | |||
| elif response.status_code == 403: | |||
| scraper = cloudscraper.create_scraper() | |||
| scraper.perform_request = ssrf_proxy.make_request | |||
| response = scraper.get(url, headers=headers, follow_redirects=True, timeout=(120, 300)) | |||
| if response.status_code != 200: | |||
| return "URL returned status code {}.".format(response.status_code) | |||
| # Detect encoding using chardet | |||
| detected_encoding = chardet.detect(response.content) | |||
| encoding = detected_encoding["encoding"] | |||
| if encoding: | |||
| try: | |||
| content = response.content.decode(encoding) | |||
| except (UnicodeDecodeError, TypeError): | |||
| content = response.text | |||
| else: | |||
| content = response.text | |||
| a = extract_using_readabilipy(content) | |||
| if not a["plain_text"] or not a["plain_text"].strip(): | |||
| return "" | |||
| res = FULL_TEMPLATE.format( | |||
| title=a["title"], | |||
| authors=a["byline"], | |||
| publish_date=a["date"], | |||
| top_image="", | |||
| text=a["plain_text"] or "", | |||
| ) | |||
| return res | |||
| def extract_using_readabilipy(html): | |||
| with tempfile.NamedTemporaryFile(delete=False, mode="w+") as f_html: | |||
| f_html.write(html) | |||
| f_html.close() | |||
| html_path = f_html.name | |||
| # Call Mozilla's Readability.js Readability.parse() function via node, writing output to a temporary file | |||
| article_json_path = html_path + ".json" | |||
| jsdir = os.path.join(find_module_path("readabilipy"), "javascript") | |||
| with chdir(jsdir): | |||
| subprocess.check_call(["node", "ExtractArticle.js", "-i", html_path, "-o", article_json_path]) | |||
| # Read output of call to Readability.parse() from JSON file and return as Python dictionary | |||
| input_json = json.loads(Path(article_json_path).read_text(encoding="utf-8")) | |||
| # Deleting files after processing | |||
| os.unlink(article_json_path) | |||
| os.unlink(html_path) | |||
| article_json: dict[str, Any] = { | |||
| "title": None, | |||
| "byline": None, | |||
| "date": None, | |||
| "content": None, | |||
| "plain_content": None, | |||
| "plain_text": None, | |||
| } | |||
| # Populate article fields from readability fields where present | |||
| if input_json: | |||
| if input_json.get("title"): | |||
| article_json["title"] = input_json["title"] | |||
| if input_json.get("byline"): | |||
| article_json["byline"] = input_json["byline"] | |||
| if input_json.get("date"): | |||
| article_json["date"] = input_json["date"] | |||
| if input_json.get("content"): | |||
| article_json["content"] = input_json["content"] | |||
| article_json["plain_content"] = plain_content(article_json["content"], False, False) | |||
| article_json["plain_text"] = extract_text_blocks_as_plain_text(article_json["plain_content"]) | |||
| if input_json.get("textContent"): | |||
| article_json["plain_text"] = input_json["textContent"] | |||
| article_json["plain_text"] = re.sub(r"\n\s*\n", "\n", article_json["plain_text"]) | |||
| return article_json | |||
| def find_module_path(module_name): | |||
| for package_path in site.getsitepackages(): | |||
| potential_path = os.path.join(package_path, module_name) | |||
| if os.path.exists(potential_path): | |||
| return potential_path | |||
| return None | |||
| @contextmanager | |||
| def chdir(path): | |||
| """Change directory in context and return to original on exit""" | |||
| # From https://stackoverflow.com/a/37996581, couldn't find a built-in | |||
| original_path = os.getcwd() | |||
| os.chdir(path) | |||
| try: | |||
| yield | |||
| finally: | |||
| os.chdir(original_path) | |||
| def extract_text_blocks_as_plain_text(paragraph_html): | |||
| # Load article as DOM | |||
| soup = BeautifulSoup(paragraph_html, "html.parser") | |||
| # Select all lists | |||
| list_elements = soup.find_all(["ul", "ol"]) | |||
| # Prefix text in all list items with "* " and make lists paragraphs | |||
| for list_element in list_elements: | |||
| plain_items = "".join( | |||
| list(filter(None, [plain_text_leaf_node(li)["text"] for li in list_element.find_all("li")])) | |||
| ) | |||
| list_element.string = plain_items | |||
| list_element.name = "p" | |||
| # Select all text blocks | |||
| text_blocks = [s.parent for s in soup.find_all(string=True)] | |||
| text_blocks = [plain_text_leaf_node(block) for block in text_blocks] | |||
| # Drop empty paragraphs | |||
| text_blocks = list(filter(lambda p: p["text"] is not None, text_blocks)) | |||
| return text_blocks | |||
| def plain_text_leaf_node(element): | |||
| # Extract all text, stripped of any child HTML elements and normalize it | |||
| plain_text = normalize_text(element.get_text()) | |||
| if plain_text != "" and element.name == "li": | |||
| plain_text = "* {}, ".format(plain_text) | |||
| if plain_text == "": | |||
| plain_text = None | |||
| if "data-node-index" in element.attrs: | |||
| plain = {"node_index": element["data-node-index"], "text": plain_text} | |||
| else: | |||
| plain = {"text": plain_text} | |||
| return plain | |||
| def plain_content(readability_content, content_digests, node_indexes): | |||
| # Load article as DOM | |||
| soup = BeautifulSoup(readability_content, "html.parser") | |||
| # Make all elements plain | |||
| elements = plain_elements(soup.contents, content_digests, node_indexes) | |||
| if node_indexes: | |||
| # Add node index attributes to nodes | |||
| elements = [add_node_indexes(element) for element in elements] | |||
| # Replace article contents with plain elements | |||
| soup.contents = elements | |||
| return str(soup) | |||
| def plain_elements(elements, content_digests, node_indexes): | |||
| # Get plain content versions of all elements | |||
| elements = [plain_element(element, content_digests, node_indexes) for element in elements] | |||
| if content_digests: | |||
| # Add content digest attribute to nodes | |||
| elements = [add_content_digest(element) for element in elements] | |||
| return elements | |||
| def plain_element(element, content_digests, node_indexes): | |||
| # For lists, we make each item plain text | |||
| if is_leaf(element): | |||
| # For leaf node elements, extract the text content, discarding any HTML tags | |||
| # 1. Get element contents as text | |||
| plain_text = element.get_text() | |||
| # 2. Normalize the extracted text string to a canonical representation | |||
| plain_text = normalize_text(plain_text) | |||
| # 3. Update element content to be plain text | |||
| element.string = plain_text | |||
| elif is_text(element): | |||
| if is_non_printing(element): | |||
| # The simplified HTML may have come from Readability.js so might | |||
| # have non-printing text (e.g. Comment or CData). In this case, we | |||
| # keep the structure, but ensure that the string is empty. | |||
| element = type(element)("") | |||
| else: | |||
| plain_text = element.string | |||
| plain_text = normalize_text(plain_text) | |||
| element = type(element)(plain_text) | |||
| else: | |||
| # If not a leaf node or leaf type call recursively on child nodes, replacing | |||
| element.contents = plain_elements(element.contents, content_digests, node_indexes) | |||
| return element | |||
| def add_node_indexes(element, node_index="0"): | |||
| # Can't add attributes to string types | |||
| if is_text(element): | |||
| return element | |||
| # Add index to current element | |||
| element["data-node-index"] = node_index | |||
| # Add index to child elements | |||
| for local_idx, child in enumerate([c for c in element.contents if not is_text(c)], start=1): | |||
| # Can't add attributes to leaf string types | |||
| child_index = "{stem}.{local}".format(stem=node_index, local=local_idx) | |||
| add_node_indexes(child, node_index=child_index) | |||
| return element | |||
| def normalize_text(text): | |||
| """Normalize unicode and whitespace.""" | |||
| # Normalize unicode first to try and standardize whitespace characters as much as possible before normalizing them | |||
| text = strip_control_characters(text) | |||
| text = normalize_unicode(text) | |||
| text = normalize_whitespace(text) | |||
| return text | |||
| def strip_control_characters(text): | |||
| """Strip out unicode control characters which might break the parsing.""" | |||
| # Unicode control characters | |||
| # [Cc]: Other, Control [includes new lines] | |||
| # [Cf]: Other, Format | |||
| # [Cn]: Other, Not Assigned | |||
| # [Co]: Other, Private Use | |||
| # [Cs]: Other, Surrogate | |||
| control_chars = {"Cc", "Cf", "Cn", "Co", "Cs"} | |||
| retained_chars = ["\t", "\n", "\r", "\f"] | |||
| # Remove non-printing control characters | |||
| return "".join( | |||
| [ | |||
| "" if (unicodedata.category(char) in control_chars) and (char not in retained_chars) else char | |||
| for char in text | |||
| ] | |||
| ) | |||
| def normalize_unicode(text): | |||
| """Normalize unicode such that things that are visually equivalent map to the same unicode string where possible.""" | |||
| normal_form: Literal["NFC", "NFD", "NFKC", "NFKD"] = "NFKC" | |||
| text = unicodedata.normalize(normal_form, text) | |||
| return text | |||
| def normalize_whitespace(text): | |||
| """Replace runs of whitespace characters with a single space as this is what happens when HTML text is displayed.""" | |||
| text = regex.sub(r"\s+", " ", text) | |||
| # Remove leading and trailing whitespace | |||
| text = text.strip() | |||
| return text | |||
| def is_leaf(element): | |||
| return element.name in {"p", "li"} | |||
| def is_text(element): | |||
| return isinstance(element, NavigableString) | |||
| def is_non_printing(element): | |||
| return any(isinstance(element, _e) for _e in [Comment, CData]) | |||
| def add_content_digest(element): | |||
| if not is_text(element): | |||
| element["data-content-digest"] = content_digest(element) | |||
| return element | |||
| def content_digest(element): | |||
| digest: Any | |||
| if is_text(element): | |||
| # Hash | |||
| trimmed_string = element.string.strip() | |||
| if trimmed_string == "": | |||
| digest = "" | |||
| else: | |||
| digest = hashlib.sha256(trimmed_string.encode("utf-8")).hexdigest() | |||
| else: | |||
| contents = element.contents | |||
| num_contents = len(contents) | |||
| if num_contents == 0: | |||
| # No hash when no child elements exist | |||
| digest = "" | |||
| elif num_contents == 1: | |||
| # If single child, use digest of child | |||
| digest = content_digest(contents[0]) | |||
| else: | |||
| # Build content digest from the "non-empty" digests of child nodes | |||
| digest = hashlib.sha256() | |||
| child_digests = list(filter(lambda x: x != "", [content_digest(content) for content in contents])) | |||
| for child in child_digests: | |||
| digest.update(child.encode("utf-8")) | |||
| digest = digest.hexdigest() | |||
| return digest | |||
| def get_image_upload_file_ids(content): | |||
| pattern = r"!\[image\]\((http?://.*?(file-preview|image-preview))\)" | |||
| matches = re.findall(pattern, content) | |||
| image_upload_file_ids = [] | |||
| for match in matches: | |||
| if match[1] == "file-preview": | |||
| content_pattern = r"files/([^/]+)/file-preview" | |||
| else: | |||
| content_pattern = r"files/([^/]+)/image-preview" | |||
| content_match = re.search(content_pattern, match[0]) | |||
| if content_match: | |||
| image_upload_file_id = content_match.group(1) | |||
| image_upload_file_ids.append(image_upload_file_id) | |||
| return image_upload_file_ids | |||
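A small illustration of the normalization pipeline defined in this removed module (`strip_control_characters`, `normalize_unicode`, `normalize_whitespace` composed by `normalize_text`); it assumes those functions are in scope.

```python
# Illustration of the normalization pipeline defined above: the vertical tab
# (a control character) is stripped, NFKC folds the combining accent and the
# no-break space, and runs of whitespace collapse to single spaces.
raw = "Cafe\u0301   menu\u00a0\u000bpage 2"
print(normalize_text(raw))  # -> "Café menu page 2"
```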
| @@ -0,0 +1 @@ | |||
| {"not_installed": [], "plugin_install_failed": []} | |||