| @@ -56,6 +56,7 @@ DEFAULTS = { | |||
| 'BILLING_ENABLED': 'False', | |||
| 'CAN_REPLACE_LOGO': 'False', | |||
| 'ETL_TYPE': 'dify', | |||
| 'KEYWORD_STORE': 'jieba', | |||
| 'BATCH_UPLOAD_LIMIT': 20 | |||
| } | |||
| @@ -183,7 +184,7 @@ class Config: | |||
| # Currently, only supports: qdrant, milvus, zilliz, weaviate | |||
| # ------------------------ | |||
| self.VECTOR_STORE = get_env('VECTOR_STORE') | |||
| self.KEYWORD_STORE = get_env('KEYWORD_STORE') | |||
| # qdrant settings | |||
| self.QDRANT_URL = get_env('QDRANT_URL') | |||
| self.QDRANT_API_KEY = get_env('QDRANT_API_KEY') | |||
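Editorial note: the new KEYWORD_STORE setting follows the same env-with-default pattern as the surrounding options, so 'jieba' is used unless the deployment overrides it. A minimal sketch of that resolution, assuming get_env simply falls back to the DEFAULTS map shown above (the real helper may do more, such as type coercion); the default values are copied from this hunk.

import os

# Defaults copied from the DEFAULTS hunk above.
DEFAULTS = {
    'ETL_TYPE': 'dify',
    'KEYWORD_STORE': 'jieba',
    'BATCH_UPLOAD_LIMIT': 20,
}

def get_env(key):
    # Environment variable wins; otherwise use the packaged default.
    return os.environ.get(key, DEFAULTS.get(key))

keyword_store = get_env('KEYWORD_STORE')  # 'jieba' unless overridden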
| @@ -9,8 +9,9 @@ from werkzeug.exceptions import NotFound | |||
| from controllers.console import api | |||
| from controllers.console.setup import setup_required | |||
| from controllers.console.wraps import account_initialization_required | |||
| from core.data_loader.loader.notion import NotionLoader | |||
| from core.indexing_runner import IndexingRunner | |||
| from core.rag.extractor.entity.extract_setting import ExtractSetting | |||
| from core.rag.extractor.notion_extractor import NotionExtractor | |||
| from extensions.ext_database import db | |||
| from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields | |||
| from libs.login import login_required | |||
| @@ -173,14 +174,14 @@ class DataSourceNotionApi(Resource): | |||
| if not data_source_binding: | |||
| raise NotFound('Data source binding not found.') | |||
| loader = NotionLoader( | |||
| notion_access_token=data_source_binding.access_token, | |||
| extractor = NotionExtractor( | |||
| notion_workspace_id=workspace_id, | |||
| notion_obj_id=page_id, | |||
| notion_page_type=page_type | |||
| notion_page_type=page_type, | |||
| notion_access_token=data_source_binding.access_token | |||
| ) | |||
| text_docs = loader.load() | |||
| text_docs = extractor.extract() | |||
| return { | |||
| 'content': "\n".join([doc.page_content for doc in text_docs]) | |||
| }, 200 | |||
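Editorial note: because the add/remove markers were lost in this rendering, the old and new preview calls are interleaved above. Read together, the change is a straight swap of the langchain-style loader for the new extractor. A condensed sketch, with argument names copied from the hunk and the shared context lines folded back in; the return value is assumed to be a list of documents exposing page_content.

# Before: NotionLoader (core.data_loader)
loader = NotionLoader(
    notion_access_token=data_source_binding.access_token,
    notion_workspace_id=workspace_id,
    notion_obj_id=page_id,
    notion_page_type=page_type
)
text_docs = loader.load()

# After: NotionExtractor (core.rag.extractor)
extractor = NotionExtractor(
    notion_workspace_id=workspace_id,
    notion_obj_id=page_id,
    notion_page_type=page_type,
    notion_access_token=data_source_binding.access_token
)
text_docs = extractor.extract()

content = "\n".join(doc.page_content for doc in text_docs)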
| @@ -192,11 +193,30 @@ class DataSourceNotionApi(Resource): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('notion_info_list', type=list, required=True, nullable=True, location='json') | |||
| parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json') | |||
| parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json') | |||
| parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json') | |||
| args = parser.parse_args() | |||
| # validate args | |||
| DocumentService.estimate_args_validate(args) | |||
| notion_info_list = args['notion_info_list'] | |||
| extract_settings = [] | |||
| for notion_info in notion_info_list: | |||
| workspace_id = notion_info['workspace_id'] | |||
| for page in notion_info['pages']: | |||
| extract_setting = ExtractSetting( | |||
| datasource_type="notion_import", | |||
| notion_info={ | |||
| "notion_workspace_id": workspace_id, | |||
| "notion_obj_id": page['page_id'], | |||
| "notion_page_type": page['type'] | |||
| }, | |||
| document_model=args['doc_form'] | |||
| ) | |||
| extract_settings.append(extract_setting) | |||
| indexing_runner = IndexingRunner() | |||
| response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id, args['notion_info_list'], args['process_rule']) | |||
| response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings, | |||
| args['process_rule'], args['doc_form'], | |||
| args['doc_language']) | |||
| return response, 200 | |||
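Editorial note: this endpoint now builds ExtractSetting objects (as shown above) and calls the unified estimator instead of the Notion-specific one. Side by side, with both calls copied from this hunk:

# Before: Notion-specific estimate
response = indexing_runner.notion_indexing_estimate(
    current_user.current_tenant_id, args['notion_info_list'], args['process_rule'])

# After: unified estimate driven by the ExtractSetting list built above
response = indexing_runner.indexing_estimate(
    current_user.current_tenant_id, extract_settings,
    args['process_rule'], args['doc_form'], args['doc_language'])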
| @@ -15,6 +15,7 @@ from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError | |||
| from core.indexing_runner import IndexingRunner | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from core.provider_manager import ProviderManager | |||
| from core.rag.extractor.entity.extract_setting import ExtractSetting | |||
| from extensions.ext_database import db | |||
| from fields.app_fields import related_app_list | |||
| from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields | |||
| @@ -178,9 +179,9 @@ class DatasetApi(Resource): | |||
| location='json', store_missing=False, | |||
| type=_validate_description_length) | |||
| parser.add_argument('indexing_technique', type=str, location='json', | |||
| choices=Dataset.INDEXING_TECHNIQUE_LIST, | |||
| nullable=True, | |||
| help='Invalid indexing technique.') | |||
| choices=Dataset.INDEXING_TECHNIQUE_LIST, | |||
| nullable=True, | |||
| help='Invalid indexing technique.') | |||
| parser.add_argument('permission', type=str, location='json', choices=( | |||
| 'only_me', 'all_team_members'), help='Invalid permission.') | |||
| parser.add_argument('retrieval_model', type=dict, location='json', help='Invalid retrieval model.') | |||
| @@ -258,7 +259,7 @@ class DatasetIndexingEstimateApi(Resource): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json') | |||
| parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json') | |||
| parser.add_argument('indexing_technique', type=str, required=True, | |||
| parser.add_argument('indexing_technique', type=str, required=True, | |||
| choices=Dataset.INDEXING_TECHNIQUE_LIST, | |||
| nullable=True, location='json') | |||
| parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json') | |||
| @@ -268,6 +269,7 @@ class DatasetIndexingEstimateApi(Resource): | |||
| args = parser.parse_args() | |||
| # validate args | |||
| DocumentService.estimate_args_validate(args) | |||
| extract_settings = [] | |||
| if args['info_list']['data_source_type'] == 'upload_file': | |||
| file_ids = args['info_list']['file_info_list']['file_ids'] | |||
| file_details = db.session.query(UploadFile).filter( | |||
| @@ -278,37 +280,44 @@ class DatasetIndexingEstimateApi(Resource): | |||
| if file_details is None: | |||
| raise NotFound("File not found.") | |||
| indexing_runner = IndexingRunner() | |||
| try: | |||
| response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details, | |||
| args['process_rule'], args['doc_form'], | |||
| args['doc_language'], args['dataset_id'], | |||
| args['indexing_technique']) | |||
| except LLMBadRequestError: | |||
| raise ProviderNotInitializeError( | |||
| "No Embedding Model available. Please configure a valid provider " | |||
| "in the Settings -> Model Provider.") | |||
| except ProviderTokenNotInitError as ex: | |||
| raise ProviderNotInitializeError(ex.description) | |||
| if file_details: | |||
| for file_detail in file_details: | |||
| extract_setting = ExtractSetting( | |||
| datasource_type="upload_file", | |||
| upload_file=file_detail, | |||
| document_model=args['doc_form'] | |||
| ) | |||
| extract_settings.append(extract_setting) | |||
| elif args['info_list']['data_source_type'] == 'notion_import': | |||
| indexing_runner = IndexingRunner() | |||
| try: | |||
| response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id, | |||
| args['info_list']['notion_info_list'], | |||
| args['process_rule'], args['doc_form'], | |||
| args['doc_language'], args['dataset_id'], | |||
| args['indexing_technique']) | |||
| except LLMBadRequestError: | |||
| raise ProviderNotInitializeError( | |||
| "No Embedding Model available. Please configure a valid provider " | |||
| "in the Settings -> Model Provider.") | |||
| except ProviderTokenNotInitError as ex: | |||
| raise ProviderNotInitializeError(ex.description) | |||
| notion_info_list = args['info_list']['notion_info_list'] | |||
| for notion_info in notion_info_list: | |||
| workspace_id = notion_info['workspace_id'] | |||
| for page in notion_info['pages']: | |||
| extract_setting = ExtractSetting( | |||
| datasource_type="notion_import", | |||
| notion_info={ | |||
| "notion_workspace_id": workspace_id, | |||
| "notion_obj_id": page['page_id'], | |||
| "notion_page_type": page['type'] | |||
| }, | |||
| document_model=args['doc_form'] | |||
| ) | |||
| extract_settings.append(extract_setting) | |||
| else: | |||
| raise ValueError('Data source type not supported') | |||
| indexing_runner = IndexingRunner() | |||
| try: | |||
| response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings, | |||
| args['process_rule'], args['doc_form'], | |||
| args['doc_language'], args['dataset_id'], | |||
| args['indexing_technique']) | |||
| except LLMBadRequestError: | |||
| raise ProviderNotInitializeError( | |||
| "No Embedding Model available. Please configure a valid provider " | |||
| "in the Settings -> Model Provider.") | |||
| except ProviderTokenNotInitError as ex: | |||
| raise ProviderNotInitializeError(ex.description) | |||
| return response, 200 | |||
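Editorial note: with the markers stripped, the old and new bodies of DatasetIndexingEstimateApi are interleaved above. The following is a reconstruction of the resulting new flow, assembled from this hunk only (unchanged query lines elided with a comment):

extract_settings = []
if args['info_list']['data_source_type'] == 'upload_file':
    file_ids = args['info_list']['file_info_list']['file_ids']
    # ... existing UploadFile query, unchanged ...
    if file_details is None:
        raise NotFound("File not found.")
    if file_details:
        for file_detail in file_details:
            extract_settings.append(ExtractSetting(
                datasource_type="upload_file",
                upload_file=file_detail,
                document_model=args['doc_form']
            ))
elif args['info_list']['data_source_type'] == 'notion_import':
    for notion_info in args['info_list']['notion_info_list']:
        workspace_id = notion_info['workspace_id']
        for page in notion_info['pages']:
            extract_settings.append(ExtractSetting(
                datasource_type="notion_import",
                notion_info={
                    "notion_workspace_id": workspace_id,
                    "notion_obj_id": page['page_id'],
                    "notion_page_type": page['type']
                },
                document_model=args['doc_form']
            ))
else:
    raise ValueError('Data source type not supported')

indexing_runner = IndexingRunner()
try:
    response = indexing_runner.indexing_estimate(
        current_user.current_tenant_id, extract_settings,
        args['process_rule'], args['doc_form'],
        args['doc_language'], args['dataset_id'],
        args['indexing_technique'])
except LLMBadRequestError:
    raise ProviderNotInitializeError(
        "No Embedding Model available. Please configure a valid provider "
        "in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
    raise ProviderNotInitializeError(ex.description)
return response, 200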
| @@ -508,4 +517,3 @@ api.add_resource(DatasetApiDeleteApi, '/datasets/api-keys/<uuid:api_key_id>') | |||
| api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info') | |||
| api.add_resource(DatasetRetrievalSettingApi, '/datasets/retrieval-setting') | |||
| api.add_resource(DatasetRetrievalSettingMockApi, '/datasets/retrieval-setting/<string:vector_type>') | |||
| @@ -32,6 +32,7 @@ from core.indexing_runner import IndexingRunner | |||
| from core.model_manager import ModelManager | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from core.model_runtime.errors.invoke import InvokeAuthorizationError | |||
| from core.rag.extractor.entity.extract_setting import ExtractSetting | |||
| from extensions.ext_database import db | |||
| from extensions.ext_redis import redis_client | |||
| from fields.document_fields import ( | |||
| @@ -95,7 +96,7 @@ class GetProcessRuleApi(Resource): | |||
| req_data = request.args | |||
| document_id = req_data.get('document_id') | |||
| # get default rules | |||
| mode = DocumentService.DEFAULT_RULES['mode'] | |||
| rules = DocumentService.DEFAULT_RULES['rules'] | |||
| @@ -362,12 +363,18 @@ class DocumentIndexingEstimateApi(DocumentResource): | |||
| if not file: | |||
| raise NotFound('File not found.') | |||
| extract_setting = ExtractSetting( | |||
| datasource_type="upload_file", | |||
| upload_file=file, | |||
| document_model=document.doc_form | |||
| ) | |||
| indexing_runner = IndexingRunner() | |||
| try: | |||
| response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, [file], | |||
| data_process_rule_dict, None, | |||
| 'English', dataset_id) | |||
| response = indexing_runner.indexing_estimate(current_user.current_tenant_id, [extract_setting], | |||
| data_process_rule_dict, document.doc_form, | |||
| 'English', dataset_id) | |||
| except LLMBadRequestError: | |||
| raise ProviderNotInitializeError( | |||
| "No Embedding Model available. Please configure a valid provider " | |||
| @@ -402,6 +409,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): | |||
| data_process_rule = documents[0].dataset_process_rule | |||
| data_process_rule_dict = data_process_rule.to_dict() | |||
| info_list = [] | |||
| extract_settings = [] | |||
| for document in documents: | |||
| if document.indexing_status in ['completed', 'error']: | |||
| raise DocumentAlreadyFinishedError() | |||
| @@ -424,42 +432,48 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): | |||
| } | |||
| info_list.append(notion_info) | |||
| if dataset.data_source_type == 'upload_file': | |||
| file_details = db.session.query(UploadFile).filter( | |||
| UploadFile.tenant_id == current_user.current_tenant_id, | |||
| UploadFile.id.in_(info_list) | |||
| ).all() | |||
| if document.data_source_type == 'upload_file': | |||
| file_id = data_source_info['upload_file_id'] | |||
| file_detail = db.session.query(UploadFile).filter( | |||
| UploadFile.tenant_id == current_user.current_tenant_id, | |||
| UploadFile.id == file_id | |||
| ).first() | |||
| if file_details is None: | |||
| raise NotFound("File not found.") | |||
| if file_detail is None: | |||
| raise NotFound("File not found.") | |||
| indexing_runner = IndexingRunner() | |||
| try: | |||
| response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details, | |||
| data_process_rule_dict, None, | |||
| 'English', dataset_id) | |||
| except LLMBadRequestError: | |||
| raise ProviderNotInitializeError( | |||
| "No Embedding Model available. Please configure a valid provider " | |||
| "in the Settings -> Model Provider.") | |||
| except ProviderTokenNotInitError as ex: | |||
| raise ProviderNotInitializeError(ex.description) | |||
| elif dataset.data_source_type == 'notion_import': | |||
| extract_setting = ExtractSetting( | |||
| datasource_type="upload_file", | |||
| upload_file=file_detail, | |||
| document_model=document.doc_form | |||
| ) | |||
| extract_settings.append(extract_setting) | |||
| elif document.data_source_type == 'notion_import': | |||
| extract_setting = ExtractSetting( | |||
| datasource_type="notion_import", | |||
| notion_info={ | |||
| "notion_workspace_id": data_source_info['notion_workspace_id'], | |||
| "notion_obj_id": data_source_info['notion_page_id'], | |||
| "notion_page_type": data_source_info['type'] | |||
| }, | |||
| document_model=document.doc_form | |||
| ) | |||
| extract_settings.append(extract_setting) | |||
| else: | |||
| raise ValueError('Data source type not supported') | |||
| indexing_runner = IndexingRunner() | |||
| try: | |||
| response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id, | |||
| info_list, | |||
| data_process_rule_dict, | |||
| None, 'English', dataset_id) | |||
| response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings, | |||
| data_process_rule_dict, document.doc_form, | |||
| 'English', dataset_id) | |||
| except LLMBadRequestError: | |||
| raise ProviderNotInitializeError( | |||
| "No Embedding Model available. Please configure a valid provider " | |||
| "in the Settings -> Model Provider.") | |||
| except ProviderTokenNotInitError as ex: | |||
| raise ProviderNotInitializeError(ex.description) | |||
| else: | |||
| raise ValueError('Data source type not supported') | |||
| return response | |||
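Editorial note: the LLMBadRequestError / ProviderTokenNotInitError translation above is now repeated verbatim around every indexing_estimate call in these controllers. A hypothetical helper, not part of this change, that could factor out that repetition; it reuses the exception classes these controllers already import (core.errors.error provides LLMBadRequestError and ProviderTokenNotInitError, and ProviderNotInitializeError is the controller-level error used throughout this diff).

from contextlib import contextmanager

@contextmanager
def provider_errors_to_http():
    # Hypothetical wrapper: translate model-provider failures into the
    # ProviderNotInitializeError raised throughout these handlers.
    try:
        yield
    except LLMBadRequestError:
        raise ProviderNotInitializeError(
            "No Embedding Model available. Please configure a valid provider "
            "in the Settings -> Model Provider.")
    except ProviderTokenNotInitError as ex:
        raise ProviderNotInitializeError(ex.description)

# Usage at a call site:
# with provider_errors_to_http():
#     response = indexing_runner.indexing_estimate(...)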
| @@ -1,107 +0,0 @@ | |||
| import tempfile | |||
| from pathlib import Path | |||
| from typing import Optional, Union | |||
| import requests | |||
| from flask import current_app | |||
| from langchain.document_loaders import Docx2txtLoader, TextLoader | |||
| from langchain.schema import Document | |||
| from core.data_loader.loader.csv_loader import CSVLoader | |||
| from core.data_loader.loader.excel import ExcelLoader | |||
| from core.data_loader.loader.html import HTMLLoader | |||
| from core.data_loader.loader.markdown import MarkdownLoader | |||
| from core.data_loader.loader.pdf import PdfLoader | |||
| from core.data_loader.loader.unstructured.unstructured_eml import UnstructuredEmailLoader | |||
| from core.data_loader.loader.unstructured.unstructured_markdown import UnstructuredMarkdownLoader | |||
| from core.data_loader.loader.unstructured.unstructured_msg import UnstructuredMsgLoader | |||
| from core.data_loader.loader.unstructured.unstructured_ppt import UnstructuredPPTLoader | |||
| from core.data_loader.loader.unstructured.unstructured_pptx import UnstructuredPPTXLoader | |||
| from core.data_loader.loader.unstructured.unstructured_text import UnstructuredTextLoader | |||
| from core.data_loader.loader.unstructured.unstructured_xml import UnstructuredXmlLoader | |||
| from extensions.ext_storage import storage | |||
| from models.model import UploadFile | |||
| SUPPORT_URL_CONTENT_TYPES = ['application/pdf', 'text/plain'] | |||
| USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" | |||
| class FileExtractor: | |||
| @classmethod | |||
| def load(cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False) -> Union[list[Document], str]: | |||
| with tempfile.TemporaryDirectory() as temp_dir: | |||
| suffix = Path(upload_file.key).suffix | |||
| file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" | |||
| storage.download(upload_file.key, file_path) | |||
| return cls.load_from_file(file_path, return_text, upload_file, is_automatic) | |||
| @classmethod | |||
| def load_from_url(cls, url: str, return_text: bool = False) -> Union[list[Document], str]: | |||
| response = requests.get(url, headers={ | |||
| "User-Agent": USER_AGENT | |||
| }) | |||
| with tempfile.TemporaryDirectory() as temp_dir: | |||
| suffix = Path(url).suffix | |||
| file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" | |||
| with open(file_path, 'wb') as file: | |||
| file.write(response.content) | |||
| return cls.load_from_file(file_path, return_text) | |||
| @classmethod | |||
| def load_from_file(cls, file_path: str, return_text: bool = False, | |||
| upload_file: Optional[UploadFile] = None, | |||
| is_automatic: bool = False) -> Union[list[Document], str]: | |||
| input_file = Path(file_path) | |||
| delimiter = '\n' | |||
| file_extension = input_file.suffix.lower() | |||
| etl_type = current_app.config['ETL_TYPE'] | |||
| unstructured_api_url = current_app.config['UNSTRUCTURED_API_URL'] | |||
| if etl_type == 'Unstructured': | |||
| if file_extension == '.xlsx': | |||
| loader = ExcelLoader(file_path) | |||
| elif file_extension == '.pdf': | |||
| loader = PdfLoader(file_path, upload_file=upload_file) | |||
| elif file_extension in ['.md', '.markdown']: | |||
| loader = UnstructuredMarkdownLoader(file_path, unstructured_api_url) if is_automatic \ | |||
| else MarkdownLoader(file_path, autodetect_encoding=True) | |||
| elif file_extension in ['.htm', '.html']: | |||
| loader = HTMLLoader(file_path) | |||
| elif file_extension in ['.docx']: | |||
| loader = Docx2txtLoader(file_path) | |||
| elif file_extension == '.csv': | |||
| loader = CSVLoader(file_path, autodetect_encoding=True) | |||
| elif file_extension == '.msg': | |||
| loader = UnstructuredMsgLoader(file_path, unstructured_api_url) | |||
| elif file_extension == '.eml': | |||
| loader = UnstructuredEmailLoader(file_path, unstructured_api_url) | |||
| elif file_extension == '.ppt': | |||
| loader = UnstructuredPPTLoader(file_path, unstructured_api_url) | |||
| elif file_extension == '.pptx': | |||
| loader = UnstructuredPPTXLoader(file_path, unstructured_api_url) | |||
| elif file_extension == '.xml': | |||
| loader = UnstructuredXmlLoader(file_path, unstructured_api_url) | |||
| else: | |||
| # txt | |||
| loader = UnstructuredTextLoader(file_path, unstructured_api_url) if is_automatic \ | |||
| else TextLoader(file_path, autodetect_encoding=True) | |||
| else: | |||
| if file_extension == '.xlsx': | |||
| loader = ExcelLoader(file_path) | |||
| elif file_extension == '.pdf': | |||
| loader = PdfLoader(file_path, upload_file=upload_file) | |||
| elif file_extension in ['.md', '.markdown']: | |||
| loader = MarkdownLoader(file_path, autodetect_encoding=True) | |||
| elif file_extension in ['.htm', '.html']: | |||
| loader = HTMLLoader(file_path) | |||
| elif file_extension in ['.docx']: | |||
| loader = Docx2txtLoader(file_path) | |||
| elif file_extension == '.csv': | |||
| loader = CSVLoader(file_path, autodetect_encoding=True) | |||
| else: | |||
| # txt | |||
| loader = TextLoader(file_path, autodetect_encoding=True) | |||
| return delimiter.join([document.page_content for document in loader.load()]) if return_text else loader.load() | |||
| @@ -1,34 +0,0 @@ | |||
| import logging | |||
| from bs4 import BeautifulSoup | |||
| from langchain.document_loaders.base import BaseLoader | |||
| from langchain.schema import Document | |||
| logger = logging.getLogger(__name__) | |||
| class HTMLLoader(BaseLoader): | |||
| """Load html files. | |||
| Args: | |||
| file_path: Path to the file to load. | |||
| """ | |||
| def __init__( | |||
| self, | |||
| file_path: str | |||
| ): | |||
| """Initialize with file path.""" | |||
| self._file_path = file_path | |||
| def load(self) -> list[Document]: | |||
| return [Document(page_content=self._load_as_text())] | |||
| def _load_as_text(self) -> str: | |||
| with open(self._file_path, "rb") as fp: | |||
| soup = BeautifulSoup(fp, 'html.parser') | |||
| text = soup.get_text() | |||
| text = text.strip() if text else '' | |||
| return text | |||
| @@ -1,55 +0,0 @@ | |||
| import logging | |||
| from typing import Optional | |||
| from langchain.document_loaders import PyPDFium2Loader | |||
| from langchain.document_loaders.base import BaseLoader | |||
| from langchain.schema import Document | |||
| from extensions.ext_storage import storage | |||
| from models.model import UploadFile | |||
| logger = logging.getLogger(__name__) | |||
| class PdfLoader(BaseLoader): | |||
| """Load pdf files. | |||
| Args: | |||
| file_path: Path to the file to load. | |||
| """ | |||
| def __init__( | |||
| self, | |||
| file_path: str, | |||
| upload_file: Optional[UploadFile] = None | |||
| ): | |||
| """Initialize with file path.""" | |||
| self._file_path = file_path | |||
| self._upload_file = upload_file | |||
| def load(self) -> list[Document]: | |||
| plaintext_file_key = '' | |||
| plaintext_file_exists = False | |||
| if self._upload_file: | |||
| if self._upload_file.hash: | |||
| plaintext_file_key = 'upload_files/' + self._upload_file.tenant_id + '/' \ | |||
| + self._upload_file.hash + '.0625.plaintext' | |||
| try: | |||
| text = storage.load(plaintext_file_key).decode('utf-8') | |||
| plaintext_file_exists = True | |||
| return [Document(page_content=text)] | |||
| except FileNotFoundError: | |||
| pass | |||
| documents = PyPDFium2Loader(file_path=self._file_path).load() | |||
| text_list = [] | |||
| for document in documents: | |||
| text_list.append(document.page_content) | |||
| text = "\n\n".join(text_list) | |||
| # save plaintext file for caching | |||
| if not plaintext_file_exists and plaintext_file_key: | |||
| storage.save(plaintext_file_key, text.encode('utf-8')) | |||
| return documents | |||
| @@ -1,12 +1,12 @@ | |||
| from collections.abc import Sequence | |||
| from typing import Any, Optional, cast | |||
| from langchain.schema import Document | |||
| from sqlalchemy import func | |||
| from core.model_manager import ModelManager | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel | |||
| from core.rag.models.document import Document | |||
| from extensions.ext_database import db | |||
| from models.dataset import Dataset, DocumentSegment | |||
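Editorial note: the only change in this hunk is the Document import, swapping the langchain schema class for the project's own core.rag.models.document.Document. A rough sketch of the shape that replacement presumably keeps, inferred from how documents are used elsewhere in this diff (page_content plus a metadata dict); the real class definition is not shown here.

from dataclasses import dataclass, field
from typing import Any

@dataclass
class Document:
    # Inferred fields only: callers in this diff read doc.page_content and
    # populate metadata keys such as doc_id, doc_hash, document_id, dataset_id,
    # and score.
    page_content: str
    metadata: dict[str, Any] = field(default_factory=dict)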
| @@ -1,13 +1,8 @@ | |||
| import logging | |||
| from typing import Optional | |||
| from flask import current_app | |||
| from core.embedding.cached_embedding import CacheEmbedding | |||
| from core.entities.application_entities import InvokeFrom | |||
| from core.index.vector_index.vector_index import VectorIndex | |||
| from core.model_manager import ModelManager | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from core.rag.datasource.vdb.vector_factory import Vector | |||
| from extensions.ext_database import db | |||
| from models.dataset import Dataset | |||
| from models.model import App, AppAnnotationSetting, Message, MessageAnnotation | |||
| @@ -45,17 +40,6 @@ class AnnotationReplyFeature: | |||
| embedding_provider_name = collection_binding_detail.provider_name | |||
| embedding_model_name = collection_binding_detail.model_name | |||
| model_manager = ModelManager() | |||
| model_instance = model_manager.get_model_instance( | |||
| tenant_id=app_record.tenant_id, | |||
| provider=embedding_provider_name, | |||
| model_type=ModelType.TEXT_EMBEDDING, | |||
| model=embedding_model_name | |||
| ) | |||
| # get embedding model | |||
| embeddings = CacheEmbedding(model_instance) | |||
| dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( | |||
| embedding_provider_name, | |||
| embedding_model_name, | |||
| @@ -71,22 +55,14 @@ class AnnotationReplyFeature: | |||
| collection_binding_id=dataset_collection_binding.id | |||
| ) | |||
| vector_index = VectorIndex( | |||
| dataset=dataset, | |||
| config=current_app.config, | |||
| embeddings=embeddings, | |||
| attributes=['doc_id', 'annotation_id', 'app_id'] | |||
| ) | |||
| vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id']) | |||
| documents = vector_index.search( | |||
| documents = vector.search_by_vector( | |||
| query=query, | |||
| search_type='similarity_score_threshold', | |||
| search_kwargs={ | |||
| 'k': 1, | |||
| 'score_threshold': score_threshold, | |||
| 'filter': { | |||
| 'group_id': [dataset.id] | |||
| } | |||
| k=1, | |||
| score_threshold=score_threshold, | |||
| filter={ | |||
| 'group_id': [dataset.id] | |||
| } | |||
| ) | |||
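Editorial note: here the annotation lookup drops the explicit ModelManager / CacheEmbedding wiring and the langchain-style search kwargs; the new Vector facade presumably resolves the embedding model from the dataset itself. Condensed before/after, copied from the hunk above:

# Before: caller builds the embedding model and searches through VectorIndex
model_instance = ModelManager().get_model_instance(
    tenant_id=app_record.tenant_id,
    provider=embedding_provider_name,
    model_type=ModelType.TEXT_EMBEDDING,
    model=embedding_model_name
)
embeddings = CacheEmbedding(model_instance)
vector_index = VectorIndex(
    dataset=dataset,
    config=current_app.config,
    embeddings=embeddings,
    attributes=['doc_id', 'annotation_id', 'app_id']
)
documents = vector_index.search(
    query=query,
    search_type='similarity_score_threshold',
    search_kwargs={'k': 1, 'score_threshold': score_threshold,
                   'filter': {'group_id': [dataset.id]}}
)

# After: the Vector facade owns the embedding setup; search args are explicit
vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id'])
documents = vector.search_by_vector(
    query=query,
    k=1,
    score_threshold=score_threshold,
    filter={'group_id': [dataset.id]}
)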
| @@ -1,51 +0,0 @@ | |||
| from flask import current_app | |||
| from langchain.embeddings import OpenAIEmbeddings | |||
| from core.embedding.cached_embedding import CacheEmbedding | |||
| from core.index.keyword_table_index.keyword_table_index import KeywordTableConfig, KeywordTableIndex | |||
| from core.index.vector_index.vector_index import VectorIndex | |||
| from core.model_manager import ModelManager | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from models.dataset import Dataset | |||
| class IndexBuilder: | |||
| @classmethod | |||
| def get_index(cls, dataset: Dataset, indexing_technique: str, ignore_high_quality_check: bool = False): | |||
| if indexing_technique == "high_quality": | |||
| if not ignore_high_quality_check and dataset.indexing_technique != 'high_quality': | |||
| return None | |||
| model_manager = ModelManager() | |||
| embedding_model = model_manager.get_model_instance( | |||
| tenant_id=dataset.tenant_id, | |||
| model_type=ModelType.TEXT_EMBEDDING, | |||
| provider=dataset.embedding_model_provider, | |||
| model=dataset.embedding_model | |||
| ) | |||
| embeddings = CacheEmbedding(embedding_model) | |||
| return VectorIndex( | |||
| dataset=dataset, | |||
| config=current_app.config, | |||
| embeddings=embeddings | |||
| ) | |||
| elif indexing_technique == "economy": | |||
| return KeywordTableIndex( | |||
| dataset=dataset, | |||
| config=KeywordTableConfig( | |||
| max_keywords_per_chunk=10 | |||
| ) | |||
| ) | |||
| else: | |||
| raise ValueError('Unknown indexing technique') | |||
| @classmethod | |||
| def get_default_high_quality_index(cls, dataset: Dataset): | |||
| embeddings = OpenAIEmbeddings(openai_api_key=' ') | |||
| return VectorIndex( | |||
| dataset=dataset, | |||
| config=current_app.config, | |||
| embeddings=embeddings | |||
| ) | |||
| @@ -1,305 +0,0 @@ | |||
| import json | |||
| import logging | |||
| from abc import abstractmethod | |||
| from typing import Any, cast | |||
| from langchain.embeddings.base import Embeddings | |||
| from langchain.schema import BaseRetriever, Document | |||
| from langchain.vectorstores import VectorStore | |||
| from core.index.base import BaseIndex | |||
| from extensions.ext_database import db | |||
| from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment | |||
| from models.dataset import Document as DatasetDocument | |||
| class BaseVectorIndex(BaseIndex): | |||
| def __init__(self, dataset: Dataset, embeddings: Embeddings): | |||
| super().__init__(dataset) | |||
| self._embeddings = embeddings | |||
| self._vector_store = None | |||
| def get_type(self) -> str: | |||
| raise NotImplementedError | |||
| @abstractmethod | |||
| def get_index_name(self, dataset: Dataset) -> str: | |||
| raise NotImplementedError | |||
| @abstractmethod | |||
| def to_index_struct(self) -> dict: | |||
| raise NotImplementedError | |||
| @abstractmethod | |||
| def _get_vector_store(self) -> VectorStore: | |||
| raise NotImplementedError | |||
| @abstractmethod | |||
| def _get_vector_store_class(self) -> type: | |||
| raise NotImplementedError | |||
| @abstractmethod | |||
| def search_by_full_text_index( | |||
| self, query: str, | |||
| **kwargs: Any | |||
| ) -> list[Document]: | |||
| raise NotImplementedError | |||
| def search( | |||
| self, query: str, | |||
| **kwargs: Any | |||
| ) -> list[Document]: | |||
| vector_store = self._get_vector_store() | |||
| vector_store = cast(self._get_vector_store_class(), vector_store) | |||
| search_type = kwargs.get('search_type') if kwargs.get('search_type') else 'similarity' | |||
| search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {} | |||
| if search_type == 'similarity_score_threshold': | |||
| score_threshold = search_kwargs.get("score_threshold") | |||
| if (score_threshold is None) or (not isinstance(score_threshold, float)): | |||
| search_kwargs['score_threshold'] = .0 | |||
| docs_with_similarity = vector_store.similarity_search_with_relevance_scores( | |||
| query, **search_kwargs | |||
| ) | |||
| docs = [] | |||
| for doc, similarity in docs_with_similarity: | |||
| doc.metadata['score'] = similarity | |||
| docs.append(doc) | |||
| return docs | |||
| # similarity k | |||
| # mmr k, fetch_k, lambda_mult | |||
| # similarity_score_threshold k | |||
| return vector_store.as_retriever( | |||
| search_type=search_type, | |||
| search_kwargs=search_kwargs | |||
| ).get_relevant_documents(query) | |||
| def get_retriever(self, **kwargs: Any) -> BaseRetriever: | |||
| vector_store = self._get_vector_store() | |||
| vector_store = cast(self._get_vector_store_class(), vector_store) | |||
| return vector_store.as_retriever(**kwargs) | |||
| def add_texts(self, texts: list[Document], **kwargs): | |||
| if self._is_origin(): | |||
| self.recreate_dataset(self.dataset) | |||
| vector_store = self._get_vector_store() | |||
| vector_store = cast(self._get_vector_store_class(), vector_store) | |||
| if kwargs.get('duplicate_check', False): | |||
| texts = self._filter_duplicate_texts(texts) | |||
| uuids = self._get_uuids(texts) | |||
| vector_store.add_documents(texts, uuids=uuids) | |||
| def text_exists(self, id: str) -> bool: | |||
| vector_store = self._get_vector_store() | |||
| vector_store = cast(self._get_vector_store_class(), vector_store) | |||
| return vector_store.text_exists(id) | |||
| def delete_by_ids(self, ids: list[str]) -> None: | |||
| if self._is_origin(): | |||
| self.recreate_dataset(self.dataset) | |||
| return | |||
| vector_store = self._get_vector_store() | |||
| vector_store = cast(self._get_vector_store_class(), vector_store) | |||
| for node_id in ids: | |||
| vector_store.del_text(node_id) | |||
| def delete_by_group_id(self, group_id: str) -> None: | |||
| vector_store = self._get_vector_store() | |||
| vector_store = cast(self._get_vector_store_class(), vector_store) | |||
| if self.dataset.collection_binding_id: | |||
| vector_store.delete_by_group_id(group_id) | |||
| else: | |||
| vector_store.delete() | |||
| def delete(self) -> None: | |||
| vector_store = self._get_vector_store() | |||
| vector_store = cast(self._get_vector_store_class(), vector_store) | |||
| vector_store.delete() | |||
| def _is_origin(self): | |||
| return False | |||
| def recreate_dataset(self, dataset: Dataset): | |||
| logging.info(f"Recreating dataset {dataset.id}") | |||
| try: | |||
| self.delete() | |||
| except Exception as e: | |||
| raise e | |||
| dataset_documents = db.session.query(DatasetDocument).filter( | |||
| DatasetDocument.dataset_id == dataset.id, | |||
| DatasetDocument.indexing_status == 'completed', | |||
| DatasetDocument.enabled == True, | |||
| DatasetDocument.archived == False, | |||
| ).all() | |||
| documents = [] | |||
| for dataset_document in dataset_documents: | |||
| segments = db.session.query(DocumentSegment).filter( | |||
| DocumentSegment.document_id == dataset_document.id, | |||
| DocumentSegment.status == 'completed', | |||
| DocumentSegment.enabled == True | |||
| ).all() | |||
| for segment in segments: | |||
| document = Document( | |||
| page_content=segment.content, | |||
| metadata={ | |||
| "doc_id": segment.index_node_id, | |||
| "doc_hash": segment.index_node_hash, | |||
| "document_id": segment.document_id, | |||
| "dataset_id": segment.dataset_id, | |||
| } | |||
| ) | |||
| documents.append(document) | |||
| origin_index_struct = self.dataset.index_struct[:] | |||
| self.dataset.index_struct = None | |||
| if documents: | |||
| try: | |||
| self.create(documents) | |||
| except Exception as e: | |||
| self.dataset.index_struct = origin_index_struct | |||
| raise e | |||
| dataset.index_struct = json.dumps(self.to_index_struct()) | |||
| db.session.commit() | |||
| self.dataset = dataset | |||
| logging.info(f"Dataset {dataset.id} recreate successfully.") | |||
| def create_qdrant_dataset(self, dataset: Dataset): | |||
| logging.info(f"create_qdrant_dataset {dataset.id}") | |||
| try: | |||
| self.delete() | |||
| except Exception as e: | |||
| raise e | |||
| dataset_documents = db.session.query(DatasetDocument).filter( | |||
| DatasetDocument.dataset_id == dataset.id, | |||
| DatasetDocument.indexing_status == 'completed', | |||
| DatasetDocument.enabled == True, | |||
| DatasetDocument.archived == False, | |||
| ).all() | |||
| documents = [] | |||
| for dataset_document in dataset_documents: | |||
| segments = db.session.query(DocumentSegment).filter( | |||
| DocumentSegment.document_id == dataset_document.id, | |||
| DocumentSegment.status == 'completed', | |||
| DocumentSegment.enabled == True | |||
| ).all() | |||
| for segment in segments: | |||
| document = Document( | |||
| page_content=segment.content, | |||
| metadata={ | |||
| "doc_id": segment.index_node_id, | |||
| "doc_hash": segment.index_node_hash, | |||
| "document_id": segment.document_id, | |||
| "dataset_id": segment.dataset_id, | |||
| } | |||
| ) | |||
| documents.append(document) | |||
| if documents: | |||
| try: | |||
| self.create(documents) | |||
| except Exception as e: | |||
| raise e | |||
| logging.info(f"Dataset {dataset.id} recreate successfully.") | |||
| def update_qdrant_dataset(self, dataset: Dataset): | |||
| logging.info(f"update_qdrant_dataset {dataset.id}") | |||
| segment = db.session.query(DocumentSegment).filter( | |||
| DocumentSegment.dataset_id == dataset.id, | |||
| DocumentSegment.status == 'completed', | |||
| DocumentSegment.enabled == True | |||
| ).first() | |||
| if segment: | |||
| try: | |||
| exist = self.text_exists(segment.index_node_id) | |||
| if exist: | |||
| index_struct = { | |||
| "type": 'qdrant', | |||
| "vector_store": {"class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']} | |||
| } | |||
| dataset.index_struct = json.dumps(index_struct) | |||
| db.session.commit() | |||
| except Exception as e: | |||
| raise e | |||
| logging.info(f"Dataset {dataset.id} recreate successfully.") | |||
| def restore_dataset_in_one(self, dataset: Dataset, dataset_collection_binding: DatasetCollectionBinding): | |||
| logging.info(f"restore dataset in_one,_dataset {dataset.id}") | |||
| dataset_documents = db.session.query(DatasetDocument).filter( | |||
| DatasetDocument.dataset_id == dataset.id, | |||
| DatasetDocument.indexing_status == 'completed', | |||
| DatasetDocument.enabled == True, | |||
| DatasetDocument.archived == False, | |||
| ).all() | |||
| documents = [] | |||
| for dataset_document in dataset_documents: | |||
| segments = db.session.query(DocumentSegment).filter( | |||
| DocumentSegment.document_id == dataset_document.id, | |||
| DocumentSegment.status == 'completed', | |||
| DocumentSegment.enabled == True | |||
| ).all() | |||
| for segment in segments: | |||
| document = Document( | |||
| page_content=segment.content, | |||
| metadata={ | |||
| "doc_id": segment.index_node_id, | |||
| "doc_hash": segment.index_node_hash, | |||
| "document_id": segment.document_id, | |||
| "dataset_id": segment.dataset_id, | |||
| } | |||
| ) | |||
| documents.append(document) | |||
| if documents: | |||
| try: | |||
| self.add_texts(documents) | |||
| except Exception as e: | |||
| raise e | |||
| logging.info(f"Dataset {dataset.id} recreate successfully.") | |||
| def delete_original_collection(self, dataset: Dataset, dataset_collection_binding: DatasetCollectionBinding): | |||
| logging.info(f"delete original collection: {dataset.id}") | |||
| self.delete() | |||
| dataset.collection_binding_id = dataset_collection_binding.id | |||
| db.session.add(dataset) | |||
| db.session.commit() | |||
| logging.info(f"Dataset {dataset.id} recreate successfully.") | |||
| @@ -1,165 +0,0 @@ | |||
| from typing import Any, cast | |||
| from langchain.embeddings.base import Embeddings | |||
| from langchain.schema import Document | |||
| from langchain.vectorstores import VectorStore | |||
| from pydantic import BaseModel, root_validator | |||
| from core.index.base import BaseIndex | |||
| from core.index.vector_index.base import BaseVectorIndex | |||
| from core.vector_store.milvus_vector_store import MilvusVectorStore | |||
| from models.dataset import Dataset | |||
| class MilvusConfig(BaseModel): | |||
| host: str | |||
| port: int | |||
| user: str | |||
| password: str | |||
| secure: bool = False | |||
| batch_size: int = 100 | |||
| @root_validator() | |||
| def validate_config(cls, values: dict) -> dict: | |||
| if not values['host']: | |||
| raise ValueError("config MILVUS_HOST is required") | |||
| if not values['port']: | |||
| raise ValueError("config MILVUS_PORT is required") | |||
| if not values['user']: | |||
| raise ValueError("config MILVUS_USER is required") | |||
| if not values['password']: | |||
| raise ValueError("config MILVUS_PASSWORD is required") | |||
| return values | |||
| def to_milvus_params(self): | |||
| return { | |||
| 'host': self.host, | |||
| 'port': self.port, | |||
| 'user': self.user, | |||
| 'password': self.password, | |||
| 'secure': self.secure | |||
| } | |||
| class MilvusVectorIndex(BaseVectorIndex): | |||
| def __init__(self, dataset: Dataset, config: MilvusConfig, embeddings: Embeddings): | |||
| super().__init__(dataset, embeddings) | |||
| self._client_config = config | |||
| def get_type(self) -> str: | |||
| return 'milvus' | |||
| def get_index_name(self, dataset: Dataset) -> str: | |||
| if self.dataset.index_struct_dict: | |||
| class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix'] | |||
| if not class_prefix.endswith('_Node'): | |||
| # original class_prefix | |||
| class_prefix += '_Node' | |||
| return class_prefix | |||
| dataset_id = dataset.id | |||
| return "Vector_index_" + dataset_id.replace("-", "_") + '_Node' | |||
| def to_index_struct(self) -> dict: | |||
| return { | |||
| "type": self.get_type(), | |||
| "vector_store": {"class_prefix": self.get_index_name(self.dataset)} | |||
| } | |||
| def create(self, texts: list[Document], **kwargs) -> BaseIndex: | |||
| uuids = self._get_uuids(texts) | |||
| index_params = { | |||
| 'metric_type': 'IP', | |||
| 'index_type': "HNSW", | |||
| 'params': {"M": 8, "efConstruction": 64} | |||
| } | |||
| self._vector_store = MilvusVectorStore.from_documents( | |||
| texts, | |||
| self._embeddings, | |||
| collection_name=self.get_index_name(self.dataset), | |||
| connection_args=self._client_config.to_milvus_params(), | |||
| index_params=index_params | |||
| ) | |||
| return self | |||
| def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex: | |||
| uuids = self._get_uuids(texts) | |||
| self._vector_store = MilvusVectorStore.from_documents( | |||
| texts, | |||
| self._embeddings, | |||
| collection_name=collection_name, | |||
| ids=uuids, | |||
| content_payload_key='page_content' | |||
| ) | |||
| return self | |||
| def _get_vector_store(self) -> VectorStore: | |||
| """Only for created index.""" | |||
| if self._vector_store: | |||
| return self._vector_store | |||
| return MilvusVectorStore( | |||
| collection_name=self.get_index_name(self.dataset), | |||
| embedding_function=self._embeddings, | |||
| connection_args=self._client_config.to_milvus_params() | |||
| ) | |||
| def _get_vector_store_class(self) -> type: | |||
| return MilvusVectorStore | |||
| def delete_by_document_id(self, document_id: str): | |||
| vector_store = self._get_vector_store() | |||
| vector_store = cast(self._get_vector_store_class(), vector_store) | |||
| ids = vector_store.get_ids_by_document_id(document_id) | |||
| if ids: | |||
| vector_store.del_texts({ | |||
| 'filter': f'id in {ids}' | |||
| }) | |||
| def delete_by_metadata_field(self, key: str, value: str): | |||
| vector_store = self._get_vector_store() | |||
| vector_store = cast(self._get_vector_store_class(), vector_store) | |||
| ids = vector_store.get_ids_by_metadata_field(key, value) | |||
| if ids: | |||
| vector_store.del_texts({ | |||
| 'filter': f'id in {ids}' | |||
| }) | |||
| def delete_by_ids(self, doc_ids: list[str]) -> None: | |||
| vector_store = self._get_vector_store() | |||
| vector_store = cast(self._get_vector_store_class(), vector_store) | |||
| ids = vector_store.get_ids_by_doc_ids(doc_ids) | |||
| vector_store.del_texts({ | |||
| 'filter': f' id in {ids}' | |||
| }) | |||
| def delete_by_group_id(self, group_id: str) -> None: | |||
| vector_store = self._get_vector_store() | |||
| vector_store = cast(self._get_vector_store_class(), vector_store) | |||
| vector_store.delete() | |||
| def delete(self) -> None: | |||
| vector_store = self._get_vector_store() | |||
| vector_store = cast(self._get_vector_store_class(), vector_store) | |||
| from qdrant_client.http import models | |||
| vector_store.del_texts(models.Filter( | |||
| must=[ | |||
| models.FieldCondition( | |||
| key="group_id", | |||
| match=models.MatchValue(value=self.dataset.id), | |||
| ), | |||
| ], | |||
| )) | |||
| def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]: | |||
| # milvus/zilliz doesn't support bm25 search | |||
| return [] | |||
| @@ -1,229 +0,0 @@ | |||
| import os | |||
| from typing import Any, Optional, cast | |||
| import qdrant_client | |||
| from langchain.embeddings.base import Embeddings | |||
| from langchain.schema import Document | |||
| from langchain.vectorstores import VectorStore | |||
| from pydantic import BaseModel | |||
| from qdrant_client.http.models import HnswConfigDiff | |||
| from core.index.base import BaseIndex | |||
| from core.index.vector_index.base import BaseVectorIndex | |||
| from core.vector_store.qdrant_vector_store import QdrantVectorStore | |||
| from extensions.ext_database import db | |||
| from models.dataset import Dataset, DatasetCollectionBinding | |||
| class QdrantConfig(BaseModel): | |||
| endpoint: str | |||
| api_key: Optional[str] | |||
| timeout: float = 20 | |||
| root_path: Optional[str] | |||
| def to_qdrant_params(self): | |||
| if self.endpoint and self.endpoint.startswith('path:'): | |||
| path = self.endpoint.replace('path:', '') | |||
| if not os.path.isabs(path): | |||
| path = os.path.join(self.root_path, path) | |||
| return { | |||
| 'path': path | |||
| } | |||
| else: | |||
| return { | |||
| 'url': self.endpoint, | |||
| 'api_key': self.api_key, | |||
| 'timeout': self.timeout | |||
| } | |||
| class QdrantVectorIndex(BaseVectorIndex): | |||
| def __init__(self, dataset: Dataset, config: QdrantConfig, embeddings: Embeddings): | |||
| super().__init__(dataset, embeddings) | |||
| self._client_config = config | |||
| def get_type(self) -> str: | |||
| return 'qdrant' | |||
| def get_index_name(self, dataset: Dataset) -> str: | |||
| if dataset.collection_binding_id: | |||
| dataset_collection_binding = db.session.query(DatasetCollectionBinding). \ | |||
| filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \ | |||
| one_or_none() | |||
| if dataset_collection_binding: | |||
| return dataset_collection_binding.collection_name | |||
| else: | |||
| raise ValueError('Dataset Collection Bindings is not exist!') | |||
| else: | |||
| if self.dataset.index_struct_dict: | |||
| class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix'] | |||
| return class_prefix | |||
| dataset_id = dataset.id | |||
| return "Vector_index_" + dataset_id.replace("-", "_") + '_Node' | |||
| def to_index_struct(self) -> dict: | |||
| return { | |||
| "type": self.get_type(), | |||
| "vector_store": {"class_prefix": self.get_index_name(self.dataset)} | |||
| } | |||
| def create(self, texts: list[Document], **kwargs) -> BaseIndex: | |||
| uuids = self._get_uuids(texts) | |||
| self._vector_store = QdrantVectorStore.from_documents( | |||
| texts, | |||
| self._embeddings, | |||
| collection_name=self.get_index_name(self.dataset), | |||
| ids=uuids, | |||
| content_payload_key='page_content', | |||
| group_id=self.dataset.id, | |||
| group_payload_key='group_id', | |||
| hnsw_config=HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000, | |||
| max_indexing_threads=0, on_disk=False), | |||
| **self._client_config.to_qdrant_params() | |||
| ) | |||
| return self | |||
| def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex: | |||
| uuids = self._get_uuids(texts) | |||
| self._vector_store = QdrantVectorStore.from_documents( | |||
| texts, | |||
| self._embeddings, | |||
| collection_name=collection_name, | |||
| ids=uuids, | |||
| content_payload_key='page_content', | |||
| group_id=self.dataset.id, | |||
| group_payload_key='group_id', | |||
| hnsw_config=HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000, | |||
| max_indexing_threads=0, on_disk=False), | |||
| **self._client_config.to_qdrant_params() | |||
| ) | |||
| return self | |||
| def _get_vector_store(self) -> VectorStore: | |||
| """Only for created index.""" | |||
| if self._vector_store: | |||
| return self._vector_store | |||
| attributes = ['doc_id', 'dataset_id', 'document_id'] | |||
| client = qdrant_client.QdrantClient( | |||
| **self._client_config.to_qdrant_params() | |||
| ) | |||
| return QdrantVectorStore( | |||
| client=client, | |||
| collection_name=self.get_index_name(self.dataset), | |||
| embeddings=self._embeddings, | |||
| content_payload_key='page_content', | |||
| group_id=self.dataset.id, | |||
| group_payload_key='group_id' | |||
| ) | |||
| def _get_vector_store_class(self) -> type: | |||
| return QdrantVectorStore | |||
| def delete_by_document_id(self, document_id: str): | |||
| vector_store = self._get_vector_store() | |||
| vector_store = cast(self._get_vector_store_class(), vector_store) | |||
| from qdrant_client.http import models | |||
| vector_store.del_texts(models.Filter( | |||
| must=[ | |||
| models.FieldCondition( | |||
| key="metadata.document_id", | |||
| match=models.MatchValue(value=document_id), | |||
| ), | |||
| ], | |||
| )) | |||
| def delete_by_metadata_field(self, key: str, value: str): | |||
| vector_store = self._get_vector_store() | |||
| vector_store = cast(self._get_vector_store_class(), vector_store) | |||
| from qdrant_client.http import models | |||
| vector_store.del_texts(models.Filter( | |||
| must=[ | |||
| models.FieldCondition( | |||
| key=f"metadata.{key}", | |||
| match=models.MatchValue(value=value), | |||
| ), | |||
| ], | |||
| )) | |||
| def delete_by_ids(self, ids: list[str]) -> None: | |||
| vector_store = self._get_vector_store() | |||
| vector_store = cast(self._get_vector_store_class(), vector_store) | |||
| from qdrant_client.http import models | |||
| for node_id in ids: | |||
| vector_store.del_texts(models.Filter( | |||
| must=[ | |||
| models.FieldCondition( | |||
| key="metadata.doc_id", | |||
| match=models.MatchValue(value=node_id), | |||
| ), | |||
| ], | |||
| )) | |||
| def delete_by_group_id(self, group_id: str) -> None: | |||
| vector_store = self._get_vector_store() | |||
| vector_store = cast(self._get_vector_store_class(), vector_store) | |||
| from qdrant_client.http import models | |||
| vector_store.del_texts(models.Filter( | |||
| must=[ | |||
| models.FieldCondition( | |||
| key="group_id", | |||
| match=models.MatchValue(value=group_id), | |||
| ), | |||
| ], | |||
| )) | |||
| def delete(self) -> None: | |||
| vector_store = self._get_vector_store() | |||
| vector_store = cast(self._get_vector_store_class(), vector_store) | |||
| from qdrant_client.http import models | |||
| vector_store.del_texts(models.Filter( | |||
| must=[ | |||
| models.FieldCondition( | |||
| key="group_id", | |||
| match=models.MatchValue(value=self.dataset.id), | |||
| ), | |||
| ], | |||
| )) | |||
| def _is_origin(self): | |||
| if self.dataset.index_struct_dict: | |||
| class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix'] | |||
| if not class_prefix.endswith('_Node'): | |||
| # original class_prefix | |||
| return True | |||
| return False | |||
| def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]: | |||
| vector_store = self._get_vector_store() | |||
| vector_store = cast(self._get_vector_store_class(), vector_store) | |||
| from qdrant_client.http import models | |||
| return vector_store.similarity_search_by_bm25(models.Filter( | |||
| must=[ | |||
| models.FieldCondition( | |||
| key="group_id", | |||
| match=models.MatchValue(value=self.dataset.id), | |||
| ), | |||
| models.FieldCondition( | |||
| key="page_content", | |||
| match=models.MatchText(text=query), | |||
| ) | |||
| ], | |||
| ), kwargs.get('top_k', 2)) | |||
| @@ -1,90 +0,0 @@ | |||
| import json | |||
| from flask import current_app | |||
| from langchain.embeddings.base import Embeddings | |||
| from core.index.vector_index.base import BaseVectorIndex | |||
| from extensions.ext_database import db | |||
| from models.dataset import Dataset, Document | |||
| class VectorIndex: | |||
| def __init__(self, dataset: Dataset, config: dict, embeddings: Embeddings, | |||
| attributes: list = None): | |||
| if attributes is None: | |||
| attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash'] | |||
| self._dataset = dataset | |||
| self._embeddings = embeddings | |||
| self._vector_index = self._init_vector_index(dataset, config, embeddings, attributes) | |||
| self._attributes = attributes | |||
| def _init_vector_index(self, dataset: Dataset, config: dict, embeddings: Embeddings, | |||
| attributes: list) -> BaseVectorIndex: | |||
| vector_type = config.get('VECTOR_STORE') | |||
| if self._dataset.index_struct_dict: | |||
| vector_type = self._dataset.index_struct_dict['type'] | |||
| if not vector_type: | |||
| raise ValueError("Vector store must be specified.") | |||
| if vector_type == "weaviate": | |||
| from core.index.vector_index.weaviate_vector_index import WeaviateConfig, WeaviateVectorIndex | |||
| return WeaviateVectorIndex( | |||
| dataset=dataset, | |||
| config=WeaviateConfig( | |||
| endpoint=config.get('WEAVIATE_ENDPOINT'), | |||
| api_key=config.get('WEAVIATE_API_KEY'), | |||
| batch_size=int(config.get('WEAVIATE_BATCH_SIZE')) | |||
| ), | |||
| embeddings=embeddings, | |||
| attributes=attributes | |||
| ) | |||
| elif vector_type == "qdrant": | |||
| from core.index.vector_index.qdrant_vector_index import QdrantConfig, QdrantVectorIndex | |||
| return QdrantVectorIndex( | |||
| dataset=dataset, | |||
| config=QdrantConfig( | |||
| endpoint=config.get('QDRANT_URL'), | |||
| api_key=config.get('QDRANT_API_KEY'), | |||
| root_path=current_app.root_path, | |||
| timeout=config.get('QDRANT_CLIENT_TIMEOUT') | |||
| ), | |||
| embeddings=embeddings | |||
| ) | |||
| elif vector_type == "milvus": | |||
| from core.index.vector_index.milvus_vector_index import MilvusConfig, MilvusVectorIndex | |||
| return MilvusVectorIndex( | |||
| dataset=dataset, | |||
| config=MilvusConfig( | |||
| host=config.get('MILVUS_HOST'), | |||
| port=config.get('MILVUS_PORT'), | |||
| user=config.get('MILVUS_USER'), | |||
| password=config.get('MILVUS_PASSWORD'), | |||
| secure=config.get('MILVUS_SECURE'), | |||
| ), | |||
| embeddings=embeddings | |||
| ) | |||
| else: | |||
| raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.") | |||
| def add_texts(self, texts: list[Document], **kwargs): | |||
| if not self._dataset.index_struct_dict: | |||
| self._vector_index.create(texts, **kwargs) | |||
| self._dataset.index_struct = json.dumps(self._vector_index.to_index_struct()) | |||
| db.session.commit() | |||
| return | |||
| self._vector_index.add_texts(texts, **kwargs) | |||
| def __getattr__(self, name): | |||
| if self._vector_index is not None: | |||
| method = getattr(self._vector_index, name) | |||
| if callable(method): | |||
| return method | |||
| raise AttributeError(f"'VectorIndex' object has no attribute '{name}'") | |||
| @@ -1,179 +0,0 @@ | |||
| from typing import Any, Optional, cast | |||
| import requests | |||
| import weaviate | |||
| from langchain.embeddings.base import Embeddings | |||
| from langchain.schema import Document | |||
| from langchain.vectorstores import VectorStore | |||
| from pydantic import BaseModel, root_validator | |||
| from core.index.base import BaseIndex | |||
| from core.index.vector_index.base import BaseVectorIndex | |||
| from core.vector_store.weaviate_vector_store import WeaviateVectorStore | |||
| from models.dataset import Dataset | |||
| class WeaviateConfig(BaseModel): | |||
| endpoint: str | |||
| api_key: Optional[str] | |||
| batch_size: int = 100 | |||
| @root_validator() | |||
| def validate_config(cls, values: dict) -> dict: | |||
| if not values['endpoint']: | |||
| raise ValueError("config WEAVIATE_ENDPOINT is required") | |||
| return values | |||
| class WeaviateVectorIndex(BaseVectorIndex): | |||
| def __init__(self, dataset: Dataset, config: WeaviateConfig, embeddings: Embeddings, attributes: list): | |||
| super().__init__(dataset, embeddings) | |||
| self._client = self._init_client(config) | |||
| self._attributes = attributes | |||
| def _init_client(self, config: WeaviateConfig) -> weaviate.Client: | |||
| auth_config = weaviate.auth.AuthApiKey(api_key=config.api_key) | |||
| weaviate.connect.connection.has_grpc = False | |||
| try: | |||
| client = weaviate.Client( | |||
| url=config.endpoint, | |||
| auth_client_secret=auth_config, | |||
| timeout_config=(5, 60), | |||
| startup_period=None | |||
| ) | |||
| except requests.exceptions.ConnectionError: | |||
| raise ConnectionError("Vector database connection error") | |||
| client.batch.configure( | |||
| # `batch_size` takes an `int` value to enable auto-batching | |||
| # (`None` is used for manual batching) | |||
| batch_size=config.batch_size, | |||
| # dynamically update the `batch_size` based on import speed | |||
| dynamic=True, | |||
| # `timeout_retries` takes an `int` value to retry on time outs | |||
| timeout_retries=3, | |||
| ) | |||
| return client | |||
| def get_type(self) -> str: | |||
| return 'weaviate' | |||
| def get_index_name(self, dataset: Dataset) -> str: | |||
| if self.dataset.index_struct_dict: | |||
| class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix'] | |||
| if not class_prefix.endswith('_Node'): | |||
| # original class_prefix | |||
| class_prefix += '_Node' | |||
| return class_prefix | |||
| dataset_id = dataset.id | |||
| return "Vector_index_" + dataset_id.replace("-", "_") + '_Node' | |||
| def to_index_struct(self) -> dict: | |||
| return { | |||
| "type": self.get_type(), | |||
| "vector_store": {"class_prefix": self.get_index_name(self.dataset)} | |||
| } | |||
| def create(self, texts: list[Document], **kwargs) -> BaseIndex: | |||
| uuids = self._get_uuids(texts) | |||
| self._vector_store = WeaviateVectorStore.from_documents( | |||
| texts, | |||
| self._embeddings, | |||
| client=self._client, | |||
| index_name=self.get_index_name(self.dataset), | |||
| uuids=uuids, | |||
| by_text=False | |||
| ) | |||
| return self | |||
| def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex: | |||
| uuids = self._get_uuids(texts) | |||
| self._vector_store = WeaviateVectorStore.from_documents( | |||
| texts, | |||
| self._embeddings, | |||
| client=self._client, | |||
| index_name=self.get_index_name(self.dataset), | |||
| uuids=uuids, | |||
| by_text=False | |||
| ) | |||
| return self | |||
| def _get_vector_store(self) -> VectorStore: | |||
| """Only for created index.""" | |||
| if self._vector_store: | |||
| return self._vector_store | |||
| attributes = self._attributes | |||
| if self._is_origin(): | |||
| attributes = ['doc_id'] | |||
| return WeaviateVectorStore( | |||
| client=self._client, | |||
| index_name=self.get_index_name(self.dataset), | |||
| text_key='text', | |||
| embedding=self._embeddings, | |||
| attributes=attributes, | |||
| by_text=False | |||
| ) | |||
| def _get_vector_store_class(self) -> type: | |||
| return WeaviateVectorStore | |||
| def delete_by_document_id(self, document_id: str): | |||
| if self._is_origin(): | |||
| self.recreate_dataset(self.dataset) | |||
| return | |||
| vector_store = self._get_vector_store() | |||
| vector_store = cast(self._get_vector_store_class(), vector_store) | |||
| vector_store.del_texts({ | |||
| "operator": "Equal", | |||
| "path": ["document_id"], | |||
| "valueText": document_id | |||
| }) | |||
| def delete_by_metadata_field(self, key: str, value: str): | |||
| vector_store = self._get_vector_store() | |||
| vector_store = cast(self._get_vector_store_class(), vector_store) | |||
| vector_store.del_texts({ | |||
| "operator": "Equal", | |||
| "path": [key], | |||
| "valueText": value | |||
| }) | |||
| def delete_by_group_id(self, group_id: str): | |||
| if self._is_origin(): | |||
| self.recreate_dataset(self.dataset) | |||
| return | |||
| vector_store = self._get_vector_store() | |||
| vector_store = cast(self._get_vector_store_class(), vector_store) | |||
| vector_store.delete() | |||
| def _is_origin(self): | |||
| if self.dataset.index_struct_dict: | |||
| class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix'] | |||
| if not class_prefix.endswith('_Node'): | |||
| # original class_prefix | |||
| return True | |||
| return False | |||
| def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]: | |||
| vector_store = self._get_vector_store() | |||
| vector_store = cast(self._get_vector_store_class(), vector_store) | |||
| return vector_store.similarity_search_by_bm25(query, kwargs.get('top_k', 2), **kwargs) | |||
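For orientation, a minimal hedged sketch of how the BM25 full-text path above could be exercised. It assumes the `WeaviateConfig` and `WeaviateVectorIndex` classes from this hunk are importable and that `dataset` and `embeddings` objects already exist in the caller; the endpoint and attribute list are placeholders.

```python
# Hedged usage sketch: exercising the new BM25 full-text search.
# `dataset` and `embeddings` are assumed to be created elsewhere in the application.
config = WeaviateConfig(
    endpoint="http://localhost:8080",  # placeholder WEAVIATE_ENDPOINT
    api_key=None,                      # optional API key
    batch_size=100,
)
index = WeaviateVectorIndex(
    dataset=dataset,
    config=config,
    embeddings=embeddings,
    attributes=["doc_id", "dataset_id", "document_id"],  # illustrative attribute list
)
# Keyword-style retrieval over the same Weaviate class used for vectors.
docs = index.search_by_full_text_index("refund policy", top_k=4)
```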
| @@ -9,20 +9,20 @@ from typing import Optional, cast | |||
| from flask import Flask, current_app | |||
| from flask_login import current_user | |||
| from langchain.schema import Document | |||
| from langchain.text_splitter import TextSplitter | |||
| from sqlalchemy.orm.exc import ObjectDeletedError | |||
| from core.data_loader.file_extractor import FileExtractor | |||
| from core.data_loader.loader.notion import NotionLoader | |||
| from core.docstore.dataset_docstore import DatasetDocumentStore | |||
| from core.errors.error import ProviderTokenNotInitError | |||
| from core.generator.llm_generator import LLMGenerator | |||
| from core.index.index import IndexBuilder | |||
| from core.model_manager import ModelInstance, ModelManager | |||
| from core.model_runtime.entities.model_entities import ModelType, PriceType | |||
| from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel | |||
| from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel | |||
| from core.rag.extractor.entity.extract_setting import ExtractSetting | |||
| from core.rag.index_processor.index_processor_base import BaseIndexProcessor | |||
| from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | |||
| from core.rag.models.document import Document | |||
| from core.splitter.fixed_text_splitter import EnhanceRecursiveCharacterTextSplitter, FixedRecursiveCharacterTextSplitter | |||
| from extensions.ext_database import db | |||
| from extensions.ext_redis import redis_client | |||
| @@ -31,7 +31,6 @@ from libs import helper | |||
| from models.dataset import Dataset, DatasetProcessRule, DocumentSegment | |||
| from models.dataset import Document as DatasetDocument | |||
| from models.model import UploadFile | |||
| from models.source import DataSourceBinding | |||
| from services.feature_service import FeatureService | |||
| @@ -57,38 +56,19 @@ class IndexingRunner: | |||
| processing_rule = db.session.query(DatasetProcessRule). \ | |||
| filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \ | |||
| first() | |||
| # load file | |||
| text_docs = self._load_data(dataset_document, processing_rule.mode == 'automatic') | |||
| # get embedding model instance | |||
| embedding_model_instance = None | |||
| if dataset.indexing_technique == 'high_quality': | |||
| if dataset.embedding_model_provider: | |||
| embedding_model_instance = self.model_manager.get_model_instance( | |||
| tenant_id=dataset.tenant_id, | |||
| provider=dataset.embedding_model_provider, | |||
| model_type=ModelType.TEXT_EMBEDDING, | |||
| model=dataset.embedding_model | |||
| ) | |||
| else: | |||
| embedding_model_instance = self.model_manager.get_default_model_instance( | |||
| tenant_id=dataset.tenant_id, | |||
| model_type=ModelType.TEXT_EMBEDDING, | |||
| ) | |||
| # get splitter | |||
| splitter = self._get_splitter(processing_rule, embedding_model_instance) | |||
| # split to documents | |||
| documents = self._step_split( | |||
| text_docs=text_docs, | |||
| splitter=splitter, | |||
| dataset=dataset, | |||
| dataset_document=dataset_document, | |||
| processing_rule=processing_rule | |||
| ) | |||
| self._build_index( | |||
| index_type = dataset_document.doc_form | |||
| index_processor = IndexProcessorFactory(index_type).init_index_processor() | |||
| # extract | |||
| text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict()) | |||
| # transform | |||
| documents = self._transform(index_processor, dataset, text_docs, processing_rule.to_dict()) | |||
| # save segment | |||
| self._load_segments(dataset, dataset_document, documents) | |||
| # load | |||
| self._load( | |||
| index_processor=index_processor, | |||
| dataset=dataset, | |||
| dataset_document=dataset_document, | |||
| documents=documents | |||
| @@ -134,39 +114,19 @@ class IndexingRunner: | |||
| filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \ | |||
| first() | |||
| # load file | |||
| text_docs = self._load_data(dataset_document, processing_rule.mode == 'automatic') | |||
| # get embedding model instance | |||
| embedding_model_instance = None | |||
| if dataset.indexing_technique == 'high_quality': | |||
| if dataset.embedding_model_provider: | |||
| embedding_model_instance = self.model_manager.get_model_instance( | |||
| tenant_id=dataset.tenant_id, | |||
| provider=dataset.embedding_model_provider, | |||
| model_type=ModelType.TEXT_EMBEDDING, | |||
| model=dataset.embedding_model | |||
| ) | |||
| else: | |||
| embedding_model_instance = self.model_manager.get_default_model_instance( | |||
| tenant_id=dataset.tenant_id, | |||
| model_type=ModelType.TEXT_EMBEDDING, | |||
| ) | |||
| # get splitter | |||
| splitter = self._get_splitter(processing_rule, embedding_model_instance) | |||
| index_type = dataset_document.doc_form | |||
| index_processor = IndexProcessorFactory(index_type).init_index_processor() | |||
| # extract | |||
| text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict()) | |||
| # split to documents | |||
| documents = self._step_split( | |||
| text_docs=text_docs, | |||
| splitter=splitter, | |||
| dataset=dataset, | |||
| dataset_document=dataset_document, | |||
| processing_rule=processing_rule | |||
| ) | |||
| # transform | |||
| documents = self._transform(index_processor, dataset, text_docs, processing_rule.to_dict()) | |||
| # save segment | |||
| self._load_segments(dataset, dataset_document, documents) | |||
| # build index | |||
| self._build_index( | |||
| # load | |||
| self._load( | |||
| index_processor=index_processor, | |||
| dataset=dataset, | |||
| dataset_document=dataset_document, | |||
| documents=documents | |||
| @@ -220,7 +180,15 @@ class IndexingRunner: | |||
| documents.append(document) | |||
| # build index | |||
| self._build_index( | |||
| # get the process rule | |||
| processing_rule = db.session.query(DatasetProcessRule). \ | |||
| filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \ | |||
| first() | |||
| index_type = dataset_document.doc_form | |||
| index_processor = IndexProcessorFactory(index_type, processing_rule.to_dict()).init_index_processor() | |||
| self._load( | |||
| index_processor=index_processor, | |||
| dataset=dataset, | |||
| dataset_document=dataset_document, | |||
| documents=documents | |||
| @@ -239,16 +207,16 @@ class IndexingRunner: | |||
| dataset_document.stopped_at = datetime.datetime.utcnow() | |||
| db.session.commit() | |||
| def file_indexing_estimate(self, tenant_id: str, file_details: list[UploadFile], tmp_processing_rule: dict, | |||
| doc_form: str = None, doc_language: str = 'English', dataset_id: str = None, | |||
| indexing_technique: str = 'economy') -> dict: | |||
| def indexing_estimate(self, tenant_id: str, extract_settings: list[ExtractSetting], tmp_processing_rule: dict, | |||
| doc_form: str = None, doc_language: str = 'English', dataset_id: str = None, | |||
| indexing_technique: str = 'economy') -> dict: | |||
| """ | |||
| Estimate the indexing for the document. | |||
| """ | |||
| # check document limit | |||
| features = FeatureService.get_features(tenant_id) | |||
| if features.billing.enabled: | |||
| count = len(file_details) | |||
| count = len(extract_settings) | |||
| batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT']) | |||
| if count > batch_upload_limit: | |||
| raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") | |||
| @@ -284,16 +252,18 @@ class IndexingRunner: | |||
| total_segments = 0 | |||
| total_price = 0 | |||
| currency = 'USD' | |||
| for file_detail in file_details: | |||
| index_type = doc_form | |||
| index_processor = IndexProcessorFactory(index_type).init_index_processor() | |||
| all_text_docs = [] | |||
| for extract_setting in extract_settings: | |||
| # extract | |||
| text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"]) | |||
| all_text_docs.extend(text_docs) | |||
| processing_rule = DatasetProcessRule( | |||
| mode=tmp_processing_rule["mode"], | |||
| rules=json.dumps(tmp_processing_rule["rules"]) | |||
| ) | |||
| # load data from file | |||
| text_docs = FileExtractor.load(file_detail, is_automatic=processing_rule.mode == 'automatic') | |||
| # get splitter | |||
| splitter = self._get_splitter(processing_rule, embedding_model_instance) | |||
| @@ -305,7 +275,6 @@ class IndexingRunner: | |||
| ) | |||
| total_segments += len(documents) | |||
| for document in documents: | |||
| if len(preview_texts) < 5: | |||
| preview_texts.append(document.page_content) | |||
| @@ -364,154 +333,8 @@ class IndexingRunner: | |||
| "preview": preview_texts | |||
| } | |||
| def notion_indexing_estimate(self, tenant_id: str, notion_info_list: list, tmp_processing_rule: dict, | |||
| doc_form: str = None, doc_language: str = 'English', dataset_id: str = None, | |||
| indexing_technique: str = 'economy') -> dict: | |||
| """ | |||
| Estimate the indexing for the document. | |||
| """ | |||
| # check document limit | |||
| features = FeatureService.get_features(tenant_id) | |||
| if features.billing.enabled: | |||
| count = len(notion_info_list) | |||
| batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT']) | |||
| if count > batch_upload_limit: | |||
| raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") | |||
| embedding_model_instance = None | |||
| if dataset_id: | |||
| dataset = Dataset.query.filter_by( | |||
| id=dataset_id | |||
| ).first() | |||
| if not dataset: | |||
| raise ValueError('Dataset not found.') | |||
| if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality': | |||
| if dataset.embedding_model_provider: | |||
| embedding_model_instance = self.model_manager.get_model_instance( | |||
| tenant_id=tenant_id, | |||
| provider=dataset.embedding_model_provider, | |||
| model_type=ModelType.TEXT_EMBEDDING, | |||
| model=dataset.embedding_model | |||
| ) | |||
| else: | |||
| embedding_model_instance = self.model_manager.get_default_model_instance( | |||
| tenant_id=tenant_id, | |||
| model_type=ModelType.TEXT_EMBEDDING, | |||
| ) | |||
| else: | |||
| if indexing_technique == 'high_quality': | |||
| embedding_model_instance = self.model_manager.get_default_model_instance( | |||
| tenant_id=tenant_id, | |||
| model_type=ModelType.TEXT_EMBEDDING | |||
| ) | |||
| # load data from notion | |||
| tokens = 0 | |||
| preview_texts = [] | |||
| total_segments = 0 | |||
| total_price = 0 | |||
| currency = 'USD' | |||
| for notion_info in notion_info_list: | |||
| workspace_id = notion_info['workspace_id'] | |||
| data_source_binding = DataSourceBinding.query.filter( | |||
| db.and_( | |||
| DataSourceBinding.tenant_id == current_user.current_tenant_id, | |||
| DataSourceBinding.provider == 'notion', | |||
| DataSourceBinding.disabled == False, | |||
| DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"' | |||
| ) | |||
| ).first() | |||
| if not data_source_binding: | |||
| raise ValueError('Data source binding not found.') | |||
| for page in notion_info['pages']: | |||
| loader = NotionLoader( | |||
| notion_access_token=data_source_binding.access_token, | |||
| notion_workspace_id=workspace_id, | |||
| notion_obj_id=page['page_id'], | |||
| notion_page_type=page['type'] | |||
| ) | |||
| documents = loader.load() | |||
| processing_rule = DatasetProcessRule( | |||
| mode=tmp_processing_rule["mode"], | |||
| rules=json.dumps(tmp_processing_rule["rules"]) | |||
| ) | |||
| # get splitter | |||
| splitter = self._get_splitter(processing_rule, embedding_model_instance) | |||
| # split to documents | |||
| documents = self._split_to_documents_for_estimate( | |||
| text_docs=documents, | |||
| splitter=splitter, | |||
| processing_rule=processing_rule | |||
| ) | |||
| total_segments += len(documents) | |||
| embedding_model_type_instance = None | |||
| if embedding_model_instance: | |||
| embedding_model_type_instance = embedding_model_instance.model_type_instance | |||
| embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance) | |||
| for document in documents: | |||
| if len(preview_texts) < 5: | |||
| preview_texts.append(document.page_content) | |||
| if indexing_technique == 'high_quality' and embedding_model_type_instance: | |||
| tokens += embedding_model_type_instance.get_num_tokens( | |||
| model=embedding_model_instance.model, | |||
| credentials=embedding_model_instance.credentials, | |||
| texts=[document.page_content] | |||
| ) | |||
| if doc_form and doc_form == 'qa_model': | |||
| model_instance = self.model_manager.get_default_model_instance( | |||
| tenant_id=tenant_id, | |||
| model_type=ModelType.LLM | |||
| ) | |||
| model_type_instance = model_instance.model_type_instance | |||
| model_type_instance = cast(LargeLanguageModel, model_type_instance) | |||
| if len(preview_texts) > 0: | |||
| # qa model document | |||
| response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0], | |||
| doc_language) | |||
| document_qa_list = self.format_split_text(response) | |||
| price_info = model_type_instance.get_price( | |||
| model=model_instance.model, | |||
| credentials=model_instance.credentials, | |||
| price_type=PriceType.INPUT, | |||
| tokens=total_segments * 2000, | |||
| ) | |||
| return { | |||
| "total_segments": total_segments * 20, | |||
| "tokens": total_segments * 2000, | |||
| "total_price": '{:f}'.format(price_info.total_amount), | |||
| "currency": price_info.currency, | |||
| "qa_preview": document_qa_list, | |||
| "preview": preview_texts | |||
| } | |||
| if embedding_model_instance: | |||
| embedding_model_type_instance = embedding_model_instance.model_type_instance | |||
| embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance) | |||
| embedding_price_info = embedding_model_type_instance.get_price( | |||
| model=embedding_model_instance.model, | |||
| credentials=embedding_model_instance.credentials, | |||
| price_type=PriceType.INPUT, | |||
| tokens=tokens | |||
| ) | |||
| total_price = '{:f}'.format(embedding_price_info.total_amount) | |||
| currency = embedding_price_info.currency | |||
| return { | |||
| "total_segments": total_segments, | |||
| "tokens": tokens, | |||
| "total_price": total_price, | |||
| "currency": currency, | |||
| "preview": preview_texts | |||
| } | |||
| def _load_data(self, dataset_document: DatasetDocument, automatic: bool = False) -> list[Document]: | |||
| def _extract(self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict) \ | |||
| -> list[Document]: | |||
| # load file | |||
| if dataset_document.data_source_type not in ["upload_file", "notion_import"]: | |||
| return [] | |||
| @@ -527,11 +350,27 @@ class IndexingRunner: | |||
| one_or_none() | |||
| if file_detail: | |||
| text_docs = FileExtractor.load(file_detail, is_automatic=automatic) | |||
| extract_setting = ExtractSetting( | |||
| datasource_type="upload_file", | |||
| upload_file=file_detail, | |||
| document_model=dataset_document.doc_form | |||
| ) | |||
| text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode']) | |||
| elif dataset_document.data_source_type == 'notion_import': | |||
| loader = NotionLoader.from_document(dataset_document) | |||
| text_docs = loader.load() | |||
| if (not data_source_info or 'notion_workspace_id' not in data_source_info | |||
| or 'notion_page_id' not in data_source_info): | |||
| raise ValueError("no notion import info found") | |||
| extract_setting = ExtractSetting( | |||
| datasource_type="notion_import", | |||
| notion_info={ | |||
| "notion_workspace_id": data_source_info['notion_workspace_id'], | |||
| "notion_obj_id": data_source_info['notion_page_id'], | |||
| "notion_page_type": data_source_info['notion_page_type'], | |||
| "document": dataset_document | |||
| }, | |||
| document_model=dataset_document.doc_form | |||
| ) | |||
| text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode']) | |||
| # update document status to splitting | |||
| self._update_document_index_status( | |||
| document_id=dataset_document.id, | |||
| @@ -545,8 +384,6 @@ class IndexingRunner: | |||
| # replace doc id to document model id | |||
| text_docs = cast(list[Document], text_docs) | |||
| for text_doc in text_docs: | |||
| # remove invalid symbol | |||
| text_doc.page_content = self.filter_string(text_doc.page_content) | |||
| text_doc.metadata['document_id'] = dataset_document.id | |||
| text_doc.metadata['dataset_id'] = dataset_document.dataset_id | |||
| @@ -787,12 +624,12 @@ class IndexingRunner: | |||
| for q, a in matches if q and a | |||
| ] | |||
| def _build_index(self, dataset: Dataset, dataset_document: DatasetDocument, documents: list[Document]) -> None: | |||
| def _load(self, index_processor: BaseIndexProcessor, dataset: Dataset, | |||
| dataset_document: DatasetDocument, documents: list[Document]) -> None: | |||
| """ | |||
| Build the index for the document. | |||
| insert index and update document/segment status to completed | |||
| """ | |||
| vector_index = IndexBuilder.get_index(dataset, 'high_quality') | |||
| keyword_table_index = IndexBuilder.get_index(dataset, 'economy') | |||
| embedding_model_instance = None | |||
| if dataset.indexing_technique == 'high_quality': | |||
| embedding_model_instance = self.model_manager.get_model_instance( | |||
| @@ -825,13 +662,8 @@ class IndexingRunner: | |||
| ) | |||
| for document in chunk_documents | |||
| ) | |||
| # save vector index | |||
| if vector_index: | |||
| vector_index.add_texts(chunk_documents) | |||
| # save keyword index | |||
| keyword_table_index.add_texts(chunk_documents) | |||
| # load index | |||
| index_processor.load(dataset, chunk_documents) | |||
| document_ids = [document.metadata['doc_id'] for document in chunk_documents] | |||
| db.session.query(DocumentSegment).filter( | |||
| @@ -911,14 +743,64 @@ class IndexingRunner: | |||
| ) | |||
| documents.append(document) | |||
| # save vector index | |||
| index = IndexBuilder.get_index(dataset, 'high_quality') | |||
| if index: | |||
| index.add_texts(documents, duplicate_check=True) | |||
| # save keyword index | |||
| index = IndexBuilder.get_index(dataset, 'economy') | |||
| if index: | |||
| index.add_texts(documents) | |||
| index_type = dataset.doc_form | |||
| index_processor = IndexProcessorFactory(index_type).init_index_processor() | |||
| index_processor.load(dataset, documents) | |||
| def _transform(self, index_processor: BaseIndexProcessor, dataset: Dataset, | |||
| text_docs: list[Document], process_rule: dict) -> list[Document]: | |||
| # get embedding model instance | |||
| embedding_model_instance = None | |||
| if dataset.indexing_technique == 'high_quality': | |||
| if dataset.embedding_model_provider: | |||
| embedding_model_instance = self.model_manager.get_model_instance( | |||
| tenant_id=dataset.tenant_id, | |||
| provider=dataset.embedding_model_provider, | |||
| model_type=ModelType.TEXT_EMBEDDING, | |||
| model=dataset.embedding_model | |||
| ) | |||
| else: | |||
| embedding_model_instance = self.model_manager.get_default_model_instance( | |||
| tenant_id=dataset.tenant_id, | |||
| model_type=ModelType.TEXT_EMBEDDING, | |||
| ) | |||
| documents = index_processor.transform(text_docs, embedding_model_instance=embedding_model_instance, | |||
| process_rule=process_rule) | |||
| return documents | |||
| def _load_segments(self, dataset, dataset_document, documents): | |||
| # save node to document segment | |||
| doc_store = DatasetDocumentStore( | |||
| dataset=dataset, | |||
| user_id=dataset_document.created_by, | |||
| document_id=dataset_document.id | |||
| ) | |||
| # add document segments | |||
| doc_store.add_documents(documents) | |||
| # update document status to indexing | |||
| cur_time = datetime.datetime.utcnow() | |||
| self._update_document_index_status( | |||
| document_id=dataset_document.id, | |||
| after_indexing_status="indexing", | |||
| extra_update_params={ | |||
| DatasetDocument.cleaning_completed_at: cur_time, | |||
| DatasetDocument.splitting_completed_at: cur_time, | |||
| } | |||
| ) | |||
| # update segment status to indexing | |||
| self._update_segments_by_document( | |||
| dataset_document_id=dataset_document.id, | |||
| update_params={ | |||
| DocumentSegment.status: "indexing", | |||
| DocumentSegment.indexing_at: datetime.datetime.utcnow() | |||
| } | |||
| ) | |||
| pass | |||
| class DocumentIsPausedException(Exception): | |||
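The run path above now follows an extract → transform → load shape. Below is a hedged sketch of that flow using only calls visible in this diff; `dataset`, `upload_file`, and `process_rule` are assumed to exist in the caller, and `"text_model"` is an illustrative doc_form value.

```python
# Hedged sketch of the refactored pipeline, assuming the classes in this diff are importable.
index_processor = IndexProcessorFactory("text_model").init_index_processor()

extract_setting = ExtractSetting(
    datasource_type="upload_file",
    upload_file=upload_file,          # UploadFile row fetched elsewhere
    document_model="text_model",
)

# extract -> transform -> load
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
documents = index_processor.transform(
    text_docs,
    embedding_model_instance=None,    # populated only for 'high_quality' indexing
    process_rule=process_rule,
)
index_processor.load(dataset, documents)
```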
| @@ -0,0 +1,38 @@ | |||
| import re | |||
| class CleanProcessor: | |||
| @classmethod | |||
| def clean(cls, text: str, process_rule: dict) -> str: | |||
| # default clean | |||
| # remove invalid symbol | |||
| text = re.sub(r'<\|', '<', text) | |||
| text = re.sub(r'\|>', '>', text) | |||
| text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\xEF\xBF\xBE]', '', text) | |||
| # Unicode U+FFFE | |||
| text = re.sub('\uFFFE', '', text) | |||
| rules = process_rule['rules'] if process_rule else None | |||
| if rules and 'pre_processing_rules' in rules: | |||
| pre_processing_rules = rules["pre_processing_rules"] | |||
| for pre_processing_rule in pre_processing_rules: | |||
| if pre_processing_rule["id"] == "remove_extra_spaces" and pre_processing_rule["enabled"] is True: | |||
| # Remove extra spaces | |||
| pattern = r'\n{3,}' | |||
| text = re.sub(pattern, '\n\n', text) | |||
| pattern = r'[\t\f\r\x20\u00a0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]{2,}' | |||
| text = re.sub(pattern, ' ', text) | |||
| elif pre_processing_rule["id"] == "remove_urls_emails" and pre_processing_rule["enabled"] is True: | |||
| # Remove email | |||
| pattern = r'([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)' | |||
| text = re.sub(pattern, '', text) | |||
| # Remove URL | |||
| pattern = r'https?://[^\s]+' | |||
| text = re.sub(pattern, '', text) | |||
| return text | |||
| def filter_string(self, text): | |||
| return text | |||
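A hedged example of calling the cleaner above with a custom rule set; the rule ids mirror the ones handled in `clean`, and the sample text is illustrative.

```python
# Sketch: default cleaning plus the 'remove_extra_spaces' pre-processing rule.
process_rule = {
    "mode": "custom",
    "rules": {
        "pre_processing_rules": [
            {"id": "remove_extra_spaces", "enabled": True},
            {"id": "remove_urls_emails", "enabled": False},
        ]
    },
}
raw = "Contact   us at <|support@example.com|>\n\n\n\nThanks"
cleaned = CleanProcessor.clean(raw, process_rule)
# '<|'/'|>' are normalized to '<'/'>', and runs of spaces / blank lines are collapsed.
```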
| @@ -0,0 +1,12 @@ | |||
| """Abstract interface for document cleaner implementations.""" | |||
| from abc import ABC, abstractmethod | |||
| class BaseCleaner(ABC): | |||
| """Interface for clean chunk content. | |||
| """ | |||
| @abstractmethod | |||
| def clean(self, content: str): | |||
| raise NotImplementedError | |||
| @@ -0,0 +1,12 @@ | |||
| """Abstract interface for document clean implementations.""" | |||
| from core.rag.cleaner.cleaner_base import BaseCleaner | |||
| class UnstructuredExtraWhitespaceCleaner(BaseCleaner): | |||
| def clean(self, content) -> str: | |||
| """clean document content.""" | |||
| from unstructured.cleaners.core import clean_extra_whitespace | |||
| # Returns "ITEM 1A: RISK FACTORS" | |||
| return clean_extra_whitespace(content) | |||
| @@ -0,0 +1,15 @@ | |||
| """Abstract interface for document clean implementations.""" | |||
| from core.rag.cleaner.cleaner_base import BaseCleaner | |||
| class UnstructuredGroupBrokenParagraphsCleaner(BaseCleaner): | |||
| def clean(self, content) -> str: | |||
| """clean document content.""" | |||
| import re | |||
| from unstructured.cleaners.core import group_broken_paragraphs | |||
| para_split_re = re.compile(r"(\s*\n\s*){3}") | |||
| return group_broken_paragraphs(content, paragraph_split=para_split_re) | |||
| @@ -0,0 +1,12 @@ | |||
| """Abstract interface for document clean implementations.""" | |||
| from core.rag.cleaner.cleaner_base import BaseCleaner | |||
| class UnstructuredNonAsciiCharsCleaner(BaseCleaner): | |||
| def clean(self, content) -> str: | |||
| """clean document content.""" | |||
| from unstructured.cleaners.core import clean_non_ascii_chars | |||
| # Returns "This text contains non-ascii characters!" | |||
| return clean_non_ascii_chars(content) | |||
| @@ -0,0 +1,11 @@ | |||
| """Abstract interface for document clean implementations.""" | |||
| from core.rag.cleaner.cleaner_base import BaseCleaner | |||
| class UnstructuredUnicodeQuotesCleaner(BaseCleaner): | |||
| def clean(self, content) -> str: | |||
| """Replaces unicode quote characters, such as the \x91 character in a string.""" | |||
| from unstructured.cleaners.core import replace_unicode_quotes | |||
| return replace_unicode_quotes(content) | |||
| @@ -0,0 +1,11 @@ | |||
| """Abstract interface for document clean implementations.""" | |||
| from core.rag.cleaner.cleaner_base import BaseCleaner | |||
| class UnstructuredTranslateTextCleaner(BaseCleaner): | |||
| def clean(self, content) -> str: | |||
| """clean document content.""" | |||
| from unstructured.cleaners.translate import translate_text | |||
| return translate_text(content) | |||
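Since every cleaner above implements the same `BaseCleaner.clean(content)` interface, they can be composed. A hedged sketch follows; the chain order and sample text are illustrative, and the `unstructured` package must be installed.

```python
# Sketch: chaining two of the cleaners defined above.
cleaners = [
    UnstructuredNonAsciiCharsCleaner(),
    UnstructuredGroupBrokenParagraphsCleaner(),
]
text = "\x88Broken\nparagraph that was split\nacross lines\x88"
for cleaner in cleaners:
    text = cleaner.clean(text)
```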
| @@ -0,0 +1,49 @@ | |||
| from typing import Optional | |||
| from core.model_manager import ModelManager | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from core.model_runtime.errors.invoke import InvokeAuthorizationError | |||
| from core.rag.data_post_processor.reorder import ReorderRunner | |||
| from core.rag.models.document import Document | |||
| from core.rerank.rerank import RerankRunner | |||
| class DataPostProcessor: | |||
| """Interface for data post-processing document. | |||
| """ | |||
| def __init__(self, tenant_id: str, reranking_model: dict, reorder_enabled: bool = False): | |||
| self.rerank_runner = self._get_rerank_runner(reranking_model, tenant_id) | |||
| self.reorder_runner = self._get_reorder_runner(reorder_enabled) | |||
| def invoke(self, query: str, documents: list[Document], score_threshold: Optional[float] = None, | |||
| top_n: Optional[int] = None, user: Optional[str] = None) -> list[Document]: | |||
| if self.rerank_runner: | |||
| documents = self.rerank_runner.run(query, documents, score_threshold, top_n, user) | |||
| if self.reorder_runner: | |||
| documents = self.reorder_runner.run(documents) | |||
| return documents | |||
| def _get_rerank_runner(self, reranking_model: dict, tenant_id: str) -> Optional[RerankRunner]: | |||
| if reranking_model: | |||
| try: | |||
| model_manager = ModelManager() | |||
| rerank_model_instance = model_manager.get_model_instance( | |||
| tenant_id=tenant_id, | |||
| provider=reranking_model['reranking_provider_name'], | |||
| model_type=ModelType.RERANK, | |||
| model=reranking_model['reranking_model_name'] | |||
| ) | |||
| except InvokeAuthorizationError: | |||
| return None | |||
| return RerankRunner(rerank_model_instance) | |||
| return None | |||
| def _get_reorder_runner(self, reorder_enabled) -> Optional[ReorderRunner]: | |||
| if reorder_enabled: | |||
| return ReorderRunner() | |||
| return None | |||
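A hedged sketch of the post-processing step; the tenant id and reranking model names are placeholders, and `candidate_documents` stands for the output of a retrieval call.

```python
# Sketch: rerank then reorder a candidate list (names here are placeholders).
reranking_model = {
    "reranking_provider_name": "cohere",             # placeholder provider
    "reranking_model_name": "rerank-english-v2.0",   # placeholder model
}
post_processor = DataPostProcessor(tenant_id, reranking_model, reorder_enabled=True)
documents = post_processor.invoke(
    query="What is the refund policy?",
    documents=candidate_documents,   # list[Document] from the retrieval step
    score_threshold=0.5,
    top_n=3,
)
```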
| @@ -0,0 +1,19 @@ | |||
| from langchain.schema import Document | |||
| class ReorderRunner: | |||
| def run(self, documents: list[Document]) -> list[Document]: | |||
| # Take elements at even indices (0, 2, 4, ...) of the documents list | |||
| odd_elements = documents[::2] | |||
| # Take elements at odd indices (1, 3, 5, ...) of the documents list | |||
| even_elements = documents[1::2] | |||
| # Reverse the odd-index elements so the best-ranked documents end up at both ends | |||
| even_elements_reversed = even_elements[::-1] | |||
| new_documents = odd_elements + even_elements_reversed | |||
| return new_documents | |||
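A small hedged illustration of the interleaved reorder above, which keeps the highest-ranked documents at both ends of the list (a common mitigation for "lost in the middle" effects).

```python
# Illustration only: documents are assumed to arrive ranked best-first.
docs = [Document(page_content=str(i)) for i in range(5)]  # ranks 0 (best) .. 4
reordered = ReorderRunner().run(docs)
print([d.page_content for d in reordered])  # ['0', '2', '4', '3', '1']
```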
| @@ -0,0 +1,21 @@ | |||
| from abc import ABC, abstractmethod | |||
| class Embeddings(ABC): | |||
| """Interface for embedding models.""" | |||
| @abstractmethod | |||
| def embed_documents(self, texts: list[str]) -> list[list[float]]: | |||
| """Embed search docs.""" | |||
| @abstractmethod | |||
| def embed_query(self, text: str) -> list[float]: | |||
| """Embed query text.""" | |||
| async def aembed_documents(self, texts: list[str]) -> list[list[float]]: | |||
| """Asynchronous Embed search docs.""" | |||
| raise NotImplementedError | |||
| async def aembed_query(self, text: str) -> list[float]: | |||
| """Asynchronous Embed query text.""" | |||
| raise NotImplementedError | |||
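A minimal sketch of satisfying the `Embeddings` interface above; the hash-based vectors are a stand-in for a real embedding model and exist only to show the contract.

```python
import hashlib

class FakeEmbeddings(Embeddings):
    """Toy implementation for tests and sketches; not a real embedding model."""

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        return [self.embed_query(text) for text in texts]

    def embed_query(self, text: str) -> list[float]:
        digest = hashlib.sha256(text.encode("utf-8")).digest()
        return [byte / 255.0 for byte in digest[:8]]  # 8-dim pseudo-embedding
```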
| @@ -2,11 +2,11 @@ import json | |||
| from collections import defaultdict | |||
| from typing import Any, Optional | |||
| from langchain.schema import BaseRetriever, Document | |||
| from pydantic import BaseModel, Extra, Field | |||
| from pydantic import BaseModel | |||
| from core.index.base import BaseIndex | |||
| from core.index.keyword_table_index.jieba_keyword_table_handler import JiebaKeywordTableHandler | |||
| from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler | |||
| from core.rag.datasource.keyword.keyword_base import BaseKeyword | |||
| from core.rag.models.document import Document | |||
| from extensions.ext_database import db | |||
| from models.dataset import Dataset, DatasetKeywordTable, DocumentSegment | |||
| @@ -15,59 +15,19 @@ class KeywordTableConfig(BaseModel): | |||
| max_keywords_per_chunk: int = 10 | |||
| class KeywordTableIndex(BaseIndex): | |||
| def __init__(self, dataset: Dataset, config: KeywordTableConfig = KeywordTableConfig()): | |||
| class Jieba(BaseKeyword): | |||
| def __init__(self, dataset: Dataset): | |||
| super().__init__(dataset) | |||
| self._config = config | |||
| self._config = KeywordTableConfig() | |||
| def create(self, texts: list[Document], **kwargs) -> BaseIndex: | |||
| def create(self, texts: list[Document], **kwargs) -> BaseKeyword: | |||
| keyword_table_handler = JiebaKeywordTableHandler() | |||
| keyword_table = {} | |||
| for text in texts: | |||
| keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk) | |||
| self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords)) | |||
| keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords)) | |||
| dataset_keyword_table = DatasetKeywordTable( | |||
| dataset_id=self.dataset.id, | |||
| keyword_table=json.dumps({ | |||
| '__type__': 'keyword_table', | |||
| '__data__': { | |||
| "index_id": self.dataset.id, | |||
| "summary": None, | |||
| "table": {} | |||
| } | |||
| }, cls=SetEncoder) | |||
| ) | |||
| db.session.add(dataset_keyword_table) | |||
| db.session.commit() | |||
| self._save_dataset_keyword_table(keyword_table) | |||
| return self | |||
| def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex: | |||
| keyword_table_handler = JiebaKeywordTableHandler() | |||
| keyword_table = {} | |||
| keyword_table = self._get_dataset_keyword_table() | |||
| for text in texts: | |||
| keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk) | |||
| self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords)) | |||
| keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords)) | |||
| dataset_keyword_table = DatasetKeywordTable( | |||
| dataset_id=self.dataset.id, | |||
| keyword_table=json.dumps({ | |||
| '__type__': 'keyword_table', | |||
| '__data__': { | |||
| "index_id": self.dataset.id, | |||
| "summary": None, | |||
| "table": {} | |||
| } | |||
| }, cls=SetEncoder) | |||
| ) | |||
| db.session.add(dataset_keyword_table) | |||
| db.session.commit() | |||
| self._save_dataset_keyword_table(keyword_table) | |||
| return self | |||
| @@ -76,8 +36,13 @@ class KeywordTableIndex(BaseIndex): | |||
| keyword_table_handler = JiebaKeywordTableHandler() | |||
| keyword_table = self._get_dataset_keyword_table() | |||
| for text in texts: | |||
| keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk) | |||
| keywords_list = kwargs.get('keywords_list', None) | |||
| for i in range(len(texts)): | |||
| text = texts[i] | |||
| if keywords_list: | |||
| keywords = keywords_list[i] | |||
| else: | |||
| keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk) | |||
| self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords)) | |||
| keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords)) | |||
| @@ -107,20 +72,13 @@ class KeywordTableIndex(BaseIndex): | |||
| self._save_dataset_keyword_table(keyword_table) | |||
| def delete_by_metadata_field(self, key: str, value: str): | |||
| pass | |||
| def get_retriever(self, **kwargs: Any) -> BaseRetriever: | |||
| return KeywordTableRetriever(index=self, **kwargs) | |||
| def search( | |||
| self, query: str, | |||
| **kwargs: Any | |||
| ) -> list[Document]: | |||
| keyword_table = self._get_dataset_keyword_table() | |||
| search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {} | |||
| k = search_kwargs.get('k') if search_kwargs.get('k') else 4 | |||
| k = kwargs.get('top_k', 4) | |||
| sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table, query, k) | |||
| @@ -150,12 +108,6 @@ class KeywordTableIndex(BaseIndex): | |||
| db.session.delete(dataset_keyword_table) | |||
| db.session.commit() | |||
| def delete_by_group_id(self, group_id: str) -> None: | |||
| dataset_keyword_table = self.dataset.dataset_keyword_table | |||
| if dataset_keyword_table: | |||
| db.session.delete(dataset_keyword_table) | |||
| db.session.commit() | |||
| def _save_dataset_keyword_table(self, keyword_table): | |||
| keyword_table_dict = { | |||
| '__type__': 'keyword_table', | |||
| @@ -242,6 +194,7 @@ class KeywordTableIndex(BaseIndex): | |||
| ).first() | |||
| if document_segment: | |||
| document_segment.keywords = keywords | |||
| db.session.add(document_segment) | |||
| db.session.commit() | |||
| def create_segment_keywords(self, node_id: str, keywords: list[str]): | |||
| @@ -272,31 +225,6 @@ class KeywordTableIndex(BaseIndex): | |||
| self._save_dataset_keyword_table(keyword_table) | |||
| class KeywordTableRetriever(BaseRetriever, BaseModel): | |||
| index: KeywordTableIndex | |||
| search_kwargs: dict = Field(default_factory=dict) | |||
| class Config: | |||
| """Configuration for this pydantic object.""" | |||
| extra = Extra.forbid | |||
| arbitrary_types_allowed = True | |||
| def get_relevant_documents(self, query: str) -> list[Document]: | |||
| """Get documents relevant for a query. | |||
| Args: | |||
| query: string to find relevant documents for | |||
| Returns: | |||
| List of relevant documents | |||
| """ | |||
| return self.index.search(query, **self.search_kwargs) | |||
| async def aget_relevant_documents(self, query: str) -> list[Document]: | |||
| raise NotImplementedError("KeywordTableRetriever does not support async") | |||
| class SetEncoder(json.JSONEncoder): | |||
| def default(self, obj): | |||
| if isinstance(obj, set): | |||
| @@ -3,7 +3,7 @@ import re | |||
| import jieba | |||
| from jieba.analyse import default_tfidf | |||
| from core.index.keyword_table_index.stopwords import STOPWORDS | |||
| from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS | |||
| class JiebaKeywordTableHandler: | |||
| @@ -3,22 +3,17 @@ from __future__ import annotations | |||
| from abc import ABC, abstractmethod | |||
| from typing import Any | |||
| from langchain.schema import BaseRetriever, Document | |||
| from core.rag.models.document import Document | |||
| from models.dataset import Dataset | |||
| class BaseIndex(ABC): | |||
| class BaseKeyword(ABC): | |||
| def __init__(self, dataset: Dataset): | |||
| self.dataset = dataset | |||
| @abstractmethod | |||
| def create(self, texts: list[Document], **kwargs) -> BaseIndex: | |||
| raise NotImplementedError | |||
| @abstractmethod | |||
| def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex: | |||
| def create(self, texts: list[Document], **kwargs) -> BaseKeyword: | |||
| raise NotImplementedError | |||
| @abstractmethod | |||
| @@ -34,31 +29,18 @@ class BaseIndex(ABC): | |||
| raise NotImplementedError | |||
| @abstractmethod | |||
| def delete_by_metadata_field(self, key: str, value: str) -> None: | |||
| raise NotImplementedError | |||
| @abstractmethod | |||
| def delete_by_group_id(self, group_id: str) -> None: | |||
| raise NotImplementedError | |||
| @abstractmethod | |||
| def delete_by_document_id(self, document_id: str): | |||
| def delete_by_document_id(self, document_id: str) -> None: | |||
| raise NotImplementedError | |||
| @abstractmethod | |||
| def get_retriever(self, **kwargs: Any) -> BaseRetriever: | |||
| def delete(self) -> None: | |||
| raise NotImplementedError | |||
| @abstractmethod | |||
| def search( | |||
| self, query: str, | |||
| **kwargs: Any | |||
| ) -> list[Document]: | |||
| raise NotImplementedError | |||
| def delete(self) -> None: | |||
| raise NotImplementedError | |||
| def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]: | |||
| for text in texts: | |||
| doc_id = text.metadata['doc_id'] | |||
| @@ -0,0 +1,60 @@ | |||
| from typing import Any, cast | |||
| from flask import current_app | |||
| from core.rag.datasource.keyword.jieba.jieba import Jieba | |||
| from core.rag.datasource.keyword.keyword_base import BaseKeyword | |||
| from core.rag.models.document import Document | |||
| from models.dataset import Dataset | |||
| class Keyword: | |||
| def __init__(self, dataset: Dataset): | |||
| self._dataset = dataset | |||
| self._keyword_processor = self._init_keyword() | |||
| def _init_keyword(self) -> BaseKeyword: | |||
| config = cast(dict, current_app.config) | |||
| keyword_type = config.get('KEYWORD_STORE') | |||
| if not keyword_type: | |||
| raise ValueError("Keyword store must be specified.") | |||
| if keyword_type == "jieba": | |||
| return Jieba( | |||
| dataset=self._dataset | |||
| ) | |||
| else: | |||
| raise ValueError(f"Keyword store {keyword_type} is not supported.") | |||
| def create(self, texts: list[Document], **kwargs): | |||
| self._keyword_processor.create(texts, **kwargs) | |||
| def add_texts(self, texts: list[Document], **kwargs): | |||
| self._keyword_processor.add_texts(texts, **kwargs) | |||
| def text_exists(self, id: str) -> bool: | |||
| return self._keyword_processor.text_exists(id) | |||
| def delete_by_ids(self, ids: list[str]) -> None: | |||
| self._keyword_processor.delete_by_ids(ids) | |||
| def delete_by_document_id(self, document_id: str) -> None: | |||
| self._keyword_processor.delete_by_document_id(document_id) | |||
| def delete(self) -> None: | |||
| self._keyword_processor.delete() | |||
| def search( | |||
| self, query: str, | |||
| **kwargs: Any | |||
| ) -> list[Document]: | |||
| return self._keyword_processor.search(query, **kwargs) | |||
| def __getattr__(self, name): | |||
| if self._keyword_processor is not None: | |||
| method = getattr(self._keyword_processor, name) | |||
| if callable(method): | |||
| return method | |||
| raise AttributeError(f"'Keyword' object has no attribute '{name}'") | |||
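A hedged usage sketch for the factory above; it assumes a Flask application context where `KEYWORD_STORE` is set (e.g. to `jieba`) and that a `Dataset` row and prepared `documents` already exist.

```python
# Sketch: building and querying the keyword index through the factory.
keyword = Keyword(dataset)
keyword.create(documents)                        # extract keywords and persist the table
hits = keyword.search("pricing plan", top_k=4)   # keyword-table retrieval
```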
| @@ -0,0 +1,165 @@ | |||
| import threading | |||
| from typing import Optional | |||
| from flask import Flask, current_app | |||
| from flask_login import current_user | |||
| from core.rag.data_post_processor.data_post_processor import DataPostProcessor | |||
| from core.rag.datasource.keyword.keyword_factory import Keyword | |||
| from core.rag.datasource.vdb.vector_factory import Vector | |||
| from extensions.ext_database import db | |||
| from models.dataset import Dataset | |||
| default_retrieval_model = { | |||
| 'search_method': 'semantic_search', | |||
| 'reranking_enable': False, | |||
| 'reranking_model': { | |||
| 'reranking_provider_name': '', | |||
| 'reranking_model_name': '' | |||
| }, | |||
| 'top_k': 2, | |||
| 'score_threshold_enabled': False | |||
| } | |||
| class RetrievalService: | |||
| @classmethod | |||
| def retrieve(cls, retrival_method: str, dataset_id: str, query: str, | |||
| top_k: int, score_threshold: Optional[float] = .0, reranking_model: Optional[dict] = None): | |||
| all_documents = [] | |||
| threads = [] | |||
| # retrieval_model source with keyword | |||
| if retrival_method == 'keyword_search': | |||
| keyword_thread = threading.Thread(target=RetrievalService.keyword_search, kwargs={ | |||
| 'flask_app': current_app._get_current_object(), | |||
| 'dataset_id': dataset_id, | |||
| 'query': query, | |||
| 'top_k': top_k, | |||
| 'all_documents': all_documents | |||
| }) | |||
| threads.append(keyword_thread) | |||
| keyword_thread.start() | |||
| # retrieval_model source with semantic | |||
| if retrival_method == 'semantic_search' or retrival_method == 'hybrid_search': | |||
| embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={ | |||
| 'flask_app': current_app._get_current_object(), | |||
| 'dataset_id': dataset_id, | |||
| 'query': query, | |||
| 'top_k': top_k, | |||
| 'score_threshold': score_threshold, | |||
| 'reranking_model': reranking_model, | |||
| 'all_documents': all_documents, | |||
| 'retrival_method': retrival_method | |||
| }) | |||
| threads.append(embedding_thread) | |||
| embedding_thread.start() | |||
| # retrieval source with full text | |||
| if retrival_method == 'full_text_search' or retrival_method == 'hybrid_search': | |||
| full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={ | |||
| 'flask_app': current_app._get_current_object(), | |||
| 'dataset_id': dataset_id, | |||
| 'query': query, | |||
| 'retrival_method': retrival_method, | |||
| 'score_threshold': score_threshold, | |||
| 'top_k': top_k, | |||
| 'reranking_model': reranking_model, | |||
| 'all_documents': all_documents | |||
| }) | |||
| threads.append(full_text_index_thread) | |||
| full_text_index_thread.start() | |||
| for thread in threads: | |||
| thread.join() | |||
| if retrival_method == 'hybrid_search': | |||
| data_post_processor = DataPostProcessor(str(current_user.current_tenant_id), reranking_model, False) | |||
| all_documents = data_post_processor.invoke( | |||
| query=query, | |||
| documents=all_documents, | |||
| score_threshold=score_threshold, | |||
| top_n=top_k | |||
| ) | |||
| return all_documents | |||
| @classmethod | |||
| def keyword_search(cls, flask_app: Flask, dataset_id: str, query: str, | |||
| top_k: int, all_documents: list): | |||
| with flask_app.app_context(): | |||
| dataset = db.session.query(Dataset).filter( | |||
| Dataset.id == dataset_id | |||
| ).first() | |||
| keyword = Keyword( | |||
| dataset=dataset | |||
| ) | |||
| documents = keyword.search( | |||
| query, | |||
| top_k=top_k | |||
| ) | |||
| all_documents.extend(documents) | |||
| @classmethod | |||
| def embedding_search(cls, flask_app: Flask, dataset_id: str, query: str, | |||
| top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict], | |||
| all_documents: list, retrival_method: str): | |||
| with flask_app.app_context(): | |||
| dataset = db.session.query(Dataset).filter( | |||
| Dataset.id == dataset_id | |||
| ).first() | |||
| vector = Vector( | |||
| dataset=dataset | |||
| ) | |||
| documents = vector.search_by_vector( | |||
| query, | |||
| search_type='similarity_score_threshold', | |||
| k=top_k, | |||
| score_threshold=score_threshold, | |||
| filter={ | |||
| 'group_id': [dataset.id] | |||
| } | |||
| ) | |||
| if documents: | |||
| if reranking_model and retrival_method == 'semantic_search': | |||
| data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False) | |||
| all_documents.extend(data_post_processor.invoke( | |||
| query=query, | |||
| documents=documents, | |||
| score_threshold=score_threshold, | |||
| top_n=len(documents) | |||
| )) | |||
| else: | |||
| all_documents.extend(documents) | |||
| @classmethod | |||
| def full_text_index_search(cls, flask_app: Flask, dataset_id: str, query: str, | |||
| top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict], | |||
| all_documents: list, retrival_method: str): | |||
| with flask_app.app_context(): | |||
| dataset = db.session.query(Dataset).filter( | |||
| Dataset.id == dataset_id | |||
| ).first() | |||
| vector_processor = Vector( | |||
| dataset=dataset, | |||
| ) | |||
| documents = vector_processor.search_by_full_text( | |||
| query, | |||
| top_k=top_k | |||
| ) | |||
| if documents: | |||
| if reranking_model and retrival_method == 'full_text_search': | |||
| data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False) | |||
| all_documents.extend(data_post_processor.invoke( | |||
| query=query, | |||
| documents=documents, | |||
| score_threshold=score_threshold, | |||
| top_n=len(documents) | |||
| )) | |||
| else: | |||
| all_documents.extend(documents) | |||
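A hedged sketch of invoking the service above; the dataset id is a placeholder, the reranking model names are illustrative, and the call assumes the Flask app and login context the service relies on.

```python
docs = RetrievalService.retrieve(
    retrival_method="hybrid_search",   # parameter name as spelled in the diff
    dataset_id="00000000-0000-0000-0000-000000000000",  # placeholder dataset id
    query="How do I reset my password?",
    top_k=4,
    score_threshold=0.0,
    reranking_model={
        "reranking_provider_name": "cohere",             # placeholder
        "reranking_model_name": "rerank-english-v2.0",   # placeholder
    },
)
```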
| @@ -0,0 +1,10 @@ | |||
| from enum import Enum | |||
| class Field(Enum): | |||
| CONTENT_KEY = "page_content" | |||
| METADATA_KEY = "metadata" | |||
| GROUP_KEY = "group_id" | |||
| VECTOR = "vector" | |||
| TEXT_KEY = "text" | |||
| PRIMARY_KEY = "id" | |||
| @@ -0,0 +1,214 @@ | |||
| import logging | |||
| from typing import Any, Optional | |||
| from uuid import uuid4 | |||
| from pydantic import BaseModel, root_validator | |||
| from pymilvus import MilvusClient, MilvusException, connections | |||
| from core.rag.datasource.vdb.field import Field | |||
| from core.rag.datasource.vdb.vector_base import BaseVector | |||
| from core.rag.models.document import Document | |||
| logger = logging.getLogger(__name__) | |||
| class MilvusConfig(BaseModel): | |||
| host: str | |||
| port: int | |||
| user: str | |||
| password: str | |||
| secure: bool = False | |||
| batch_size: int = 100 | |||
| @root_validator() | |||
| def validate_config(cls, values: dict) -> dict: | |||
| if not values['host']: | |||
| raise ValueError("config MILVUS_HOST is required") | |||
| if not values['port']: | |||
| raise ValueError("config MILVUS_PORT is required") | |||
| if not values['user']: | |||
| raise ValueError("config MILVUS_USER is required") | |||
| if not values['password']: | |||
| raise ValueError("config MILVUS_PASSWORD is required") | |||
| return values | |||
| def to_milvus_params(self): | |||
| return { | |||
| 'host': self.host, | |||
| 'port': self.port, | |||
| 'user': self.user, | |||
| 'password': self.password, | |||
| 'secure': self.secure | |||
| } | |||
| class MilvusVector(BaseVector): | |||
| def __init__(self, collection_name: str, config: MilvusConfig): | |||
| super().__init__(collection_name) | |||
| self._client_config = config | |||
| self._client = self._init_client(config) | |||
| self._consistency_level = 'Session' | |||
| self._fields = [] | |||
| def get_type(self) -> str: | |||
| return 'milvus' | |||
| def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): | |||
| index_params = { | |||
| 'metric_type': 'IP', | |||
| 'index_type': "HNSW", | |||
| 'params': {"M": 8, "efConstruction": 64} | |||
| } | |||
| metadatas = [d.metadata for d in texts] | |||
| # Grab the existing collection if it exists | |||
| from pymilvus import utility | |||
| alias = uuid4().hex | |||
| if self._client_config.secure: | |||
| uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port) | |||
| else: | |||
| uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port) | |||
| connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password) | |||
| if not utility.has_collection(self._collection_name, using=alias): | |||
| self.create_collection(embeddings, metadatas, index_params) | |||
| self.add_texts(texts, embeddings) | |||
| def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): | |||
| insert_dict_list = [] | |||
| for i in range(len(documents)): | |||
| insert_dict = { | |||
| Field.CONTENT_KEY.value: documents[i].page_content, | |||
| Field.VECTOR.value: embeddings[i], | |||
| Field.METADATA_KEY.value: documents[i].metadata | |||
| } | |||
| insert_dict_list.append(insert_dict) | |||
| # Total insert count | |||
| total_count = len(insert_dict_list) | |||
| pks: list[str] = [] | |||
| for i in range(0, total_count, 1000): | |||
| batch_insert_list = insert_dict_list[i:i + 1000] | |||
| # Insert into the collection. | |||
| try: | |||
| ids = self._client.insert(collection_name=self._collection_name, data=batch_insert_list) | |||
| pks.extend(ids) | |||
| except MilvusException as e: | |||
| logger.error( | |||
| "Failed to insert batch starting at entity: %s/%s", i, total_count | |||
| ) | |||
| raise e | |||
| return pks | |||
| def delete_by_document_id(self, document_id: str): | |||
| ids = self.get_ids_by_metadata_field('document_id', document_id) | |||
| if ids: | |||
| self._client.delete(collection_name=self._collection_name, pks=ids) | |||
| def get_ids_by_metadata_field(self, key: str, value: str): | |||
| result = self._client.query(collection_name=self._collection_name, | |||
| filter=f'metadata["{key}"] == "{value}"', | |||
| output_fields=["id"]) | |||
| if result: | |||
| return [item["id"] for item in result] | |||
| else: | |||
| return None | |||
| def delete_by_metadata_field(self, key: str, value: str): | |||
| ids = self.get_ids_by_metadata_field(key, value) | |||
| if ids: | |||
| self._client.delete(collection_name=self._collection_name, pks=ids) | |||
| def delete_by_ids(self, doc_ids: list[str]) -> None: | |||
| self._client.delete(collection_name=self._collection_name, pks=doc_ids) | |||
| def delete(self) -> None: | |||
| from pymilvus import utility | |||
| utility.drop_collection(self._collection_name, None) | |||
| def text_exists(self, id: str) -> bool: | |||
| result = self._client.query(collection_name=self._collection_name, | |||
| filter=f'metadata["doc_id"] == "{id}"', | |||
| output_fields=["id"]) | |||
| return len(result) > 0 | |||
| def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: | |||
| # Set search parameters. | |||
| results = self._client.search(collection_name=self._collection_name, | |||
| data=[query_vector], | |||
| limit=kwargs.get('top_k', 4), | |||
| output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value], | |||
| ) | |||
| # Organize results. | |||
| docs = [] | |||
| for result in results[0]: | |||
| metadata = result['entity'].get(Field.METADATA_KEY.value) | |||
| metadata['score'] = result['distance'] | |||
| score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0 | |||
| if result['distance'] > score_threshold: | |||
| doc = Document(page_content=result['entity'].get(Field.CONTENT_KEY.value), | |||
| metadata=metadata) | |||
| docs.append(doc) | |||
| return docs | |||
| def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: | |||
| # milvus/zilliz doesn't support bm25 search | |||
| return [] | |||
| def create_collection( | |||
| self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None | |||
| ) -> str: | |||
| from pymilvus import CollectionSchema, DataType, FieldSchema | |||
| from pymilvus.orm.types import infer_dtype_bydata | |||
| # Determine embedding dim | |||
| dim = len(embeddings[0]) | |||
| fields = [] | |||
| if metadatas: | |||
| fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535)) | |||
| # Create the text field | |||
| fields.append( | |||
| FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535) | |||
| ) | |||
| # Create the primary key field | |||
| fields.append( | |||
| FieldSchema( | |||
| Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True | |||
| ) | |||
| ) | |||
| # Create the vector field, supports binary or float vectors | |||
| fields.append( | |||
| FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim) | |||
| ) | |||
| # Create the schema for the collection | |||
| schema = CollectionSchema(fields) | |||
| for x in schema.fields: | |||
| self._fields.append(x.name) | |||
| # Since primary field is auto-id, no need to track it | |||
| self._fields.remove(Field.PRIMARY_KEY.value) | |||
| # Create the collection | |||
| collection_name = self._collection_name | |||
| self._client.create_collection_with_schema(collection_name=collection_name, | |||
| schema=schema, index_param=index_params, | |||
| consistency_level=self._consistency_level) | |||
| return collection_name | |||
| def _init_client(self, config) -> MilvusClient: | |||
| if config.secure: | |||
| uri = "https://" + str(config.host) + ":" + str(config.port) | |||
| else: | |||
| uri = "http://" + str(config.host) + ":" + str(config.port) | |||
| client = MilvusClient(uri=uri, user=config.user, password=config.password) | |||
| return client | |||
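A hedged sketch of wiring up the Milvus store defined above; host, port, and credentials are placeholders, and `texts`/`embeddings` are assumed to be prepared elsewhere.

```python
config = MilvusConfig(
    host="127.0.0.1",   # placeholder MILVUS_HOST
    port=19530,         # placeholder MILVUS_PORT
    user="root",        # placeholder credentials
    password="milvus",
    secure=False,
)
vector = MilvusVector(collection_name="Vector_index_example_Node", config=config)
# texts: list[Document], embeddings: list[list[float]] computed by the embedding model
# vector.create(texts, embeddings)
# results = vector.search_by_vector(query_vector, top_k=4, score_threshold=0.5)
```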
| @@ -0,0 +1,360 @@ | |||
| import os | |||
| import uuid | |||
| from collections.abc import Generator, Iterable, Sequence | |||
| from itertools import islice | |||
| from typing import TYPE_CHECKING, Any, Optional, Union, cast | |||
| import qdrant_client | |||
| from pydantic import BaseModel | |||
| from qdrant_client.http import models as rest | |||
| from qdrant_client.http.models import ( | |||
| FilterSelector, | |||
| HnswConfigDiff, | |||
| PayloadSchemaType, | |||
| TextIndexParams, | |||
| TextIndexType, | |||
| TokenizerType, | |||
| ) | |||
| from qdrant_client.local.qdrant_local import QdrantLocal | |||
| from core.rag.datasource.vdb.field import Field | |||
| from core.rag.datasource.vdb.vector_base import BaseVector | |||
| from core.rag.models.document import Document | |||
| if TYPE_CHECKING: | |||
| from qdrant_client import grpc # noqa | |||
| from qdrant_client.conversions import common_types | |||
| from qdrant_client.http import models as rest | |||
| DictFilter = dict[str, Union[str, int, bool, dict, list]] | |||
| MetadataFilter = Union[DictFilter, common_types.Filter] | |||
| class QdrantConfig(BaseModel): | |||
| endpoint: str | |||
| api_key: Optional[str] | |||
| timeout: float = 20 | |||
| root_path: Optional[str] | |||
| def to_qdrant_params(self): | |||
| if self.endpoint and self.endpoint.startswith('path:'): | |||
| path = self.endpoint.replace('path:', '') | |||
| if not os.path.isabs(path): | |||
| path = os.path.join(self.root_path, path) | |||
| return { | |||
| 'path': path | |||
| } | |||
| else: | |||
| return { | |||
| 'url': self.endpoint, | |||
| 'api_key': self.api_key, | |||
| 'timeout': self.timeout | |||
| } | |||
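The two endpoint forms handled by `to_qdrant_params()` above, shown as a hedged sketch; the paths and URL are placeholders.

```python
# 'path:' endpoints resolve to a local/embedded Qdrant directory...
local = QdrantConfig(endpoint="path:storage/qdrant", api_key=None, root_path="/app")
print(local.to_qdrant_params())   # {'path': '/app/storage/qdrant'}

# ...anything else is treated as a server URL.
remote = QdrantConfig(endpoint="http://localhost:6333", api_key=None, root_path=None)
print(remote.to_qdrant_params())  # {'url': 'http://localhost:6333', 'api_key': None, 'timeout': 20}
```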
| class QdrantVector(BaseVector): | |||
| def __init__(self, collection_name: str, group_id: str, config: QdrantConfig, distance_func: str = 'Cosine'): | |||
| super().__init__(collection_name) | |||
| self._client_config = config | |||
| self._client = qdrant_client.QdrantClient(**self._client_config.to_qdrant_params()) | |||
| self._distance_func = distance_func.upper() | |||
| self._group_id = group_id | |||
| def get_type(self) -> str: | |||
| return 'qdrant' | |||
| def to_index_struct(self) -> dict: | |||
| return { | |||
| "type": self.get_type(), | |||
| "vector_store": {"class_prefix": self._collection_name} | |||
| } | |||
| def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): | |||
| if texts: | |||
| # get embedding vector size | |||
| vector_size = len(embeddings[0]) | |||
| # get collection name | |||
| collection_name = self._collection_name | |||
| collection_name = collection_name or uuid.uuid4().hex | |||
| all_collection_name = [] | |||
| collections_response = self._client.get_collections() | |||
| collection_list = collections_response.collections | |||
| for collection in collection_list: | |||
| all_collection_name.append(collection.name) | |||
| if collection_name not in all_collection_name: | |||
| # create collection | |||
| self.create_collection(collection_name, vector_size) | |||
| self.add_texts(texts, embeddings, **kwargs) | |||
| def create_collection(self, collection_name: str, vector_size: int): | |||
| from qdrant_client.http import models as rest | |||
| vectors_config = rest.VectorParams( | |||
| size=vector_size, | |||
| distance=rest.Distance[self._distance_func], | |||
| ) | |||
| hnsw_config = HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000, | |||
| max_indexing_threads=0, on_disk=False) | |||
| self._client.recreate_collection( | |||
| collection_name=collection_name, | |||
| vectors_config=vectors_config, | |||
| hnsw_config=hnsw_config, | |||
| timeout=int(self._client_config.timeout), | |||
| ) | |||
| # create payload index | |||
| self._client.create_payload_index(collection_name, Field.GROUP_KEY.value, | |||
| field_schema=PayloadSchemaType.KEYWORD, | |||
| field_type=PayloadSchemaType.KEYWORD) | |||
| # create full-text index | |||
| text_index_params = TextIndexParams( | |||
| type=TextIndexType.TEXT, | |||
| tokenizer=TokenizerType.MULTILINGUAL, | |||
| min_token_len=2, | |||
| max_token_len=20, | |||
| lowercase=True | |||
| ) | |||
| self._client.create_payload_index(collection_name, Field.CONTENT_KEY.value, | |||
| field_schema=text_index_params) | |||
| def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): | |||
| uuids = self._get_uuids(documents) | |||
| texts = [d.page_content for d in documents] | |||
| metadatas = [d.metadata for d in documents] | |||
| added_ids = [] | |||
| for batch_ids, points in self._generate_rest_batches( | |||
| texts, embeddings, metadatas, uuids, 64, self._group_id | |||
| ): | |||
| self._client.upsert( | |||
| collection_name=self._collection_name, points=points | |||
| ) | |||
| added_ids.extend(batch_ids) | |||
| return added_ids | |||
| def _generate_rest_batches( | |||
| self, | |||
| texts: Iterable[str], | |||
| embeddings: list[list[float]], | |||
| metadatas: Optional[list[dict]] = None, | |||
| ids: Optional[Sequence[str]] = None, | |||
| batch_size: int = 64, | |||
| group_id: Optional[str] = None, | |||
| ) -> Generator[tuple[list[str], list[rest.PointStruct]], None, None]: | |||
| from qdrant_client.http import models as rest | |||
| texts_iterator = iter(texts) | |||
| embeddings_iterator = iter(embeddings) | |||
| metadatas_iterator = iter(metadatas or []) | |||
| ids_iterator = iter(ids or [uuid.uuid4().hex for _ in iter(texts)]) | |||
| while batch_texts := list(islice(texts_iterator, batch_size)): | |||
| # Take the corresponding metadata and id for each text in a batch | |||
| batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None | |||
| batch_ids = list(islice(ids_iterator, batch_size)) | |||
| # Take the corresponding precomputed embeddings for the batch | |||
| batch_embeddings = list(islice(embeddings_iterator, batch_size)) | |||
| points = [ | |||
| rest.PointStruct( | |||
| id=point_id, | |||
| vector=vector, | |||
| payload=payload, | |||
| ) | |||
| for point_id, vector, payload in zip( | |||
| batch_ids, | |||
| batch_embeddings, | |||
| self._build_payloads( | |||
| batch_texts, | |||
| batch_metadatas, | |||
| Field.CONTENT_KEY.value, | |||
| Field.METADATA_KEY.value, | |||
| group_id, | |||
| Field.GROUP_KEY.value, | |||
| ), | |||
| ) | |||
| ] | |||
| yield batch_ids, points | |||
| @classmethod | |||
| def _build_payloads( | |||
| cls, | |||
| texts: Iterable[str], | |||
| metadatas: Optional[list[dict]], | |||
| content_payload_key: str, | |||
| metadata_payload_key: str, | |||
| group_id: str, | |||
| group_payload_key: str | |||
| ) -> list[dict]: | |||
| payloads = [] | |||
| for i, text in enumerate(texts): | |||
| if text is None: | |||
| raise ValueError( | |||
| "At least one of the texts is None. Please remove it before " | |||
| "calling .from_texts or .add_texts on Qdrant instance." | |||
| ) | |||
| metadata = metadatas[i] if metadatas is not None else None | |||
| payloads.append( | |||
| { | |||
| content_payload_key: text, | |||
| metadata_payload_key: metadata, | |||
| group_payload_key: group_id | |||
| } | |||
| ) | |||
| return payloads | |||
| def delete_by_metadata_field(self, key: str, value: str): | |||
| from qdrant_client.http import models | |||
| filter = models.Filter( | |||
| must=[ | |||
| models.FieldCondition( | |||
| key=f"metadata.{key}", | |||
| match=models.MatchValue(value=value), | |||
| ), | |||
| ], | |||
| ) | |||
| self._reload_if_needed() | |||
| self._client.delete( | |||
| collection_name=self._collection_name, | |||
| points_selector=FilterSelector( | |||
| filter=filter | |||
| ), | |||
| ) | |||
| def delete(self): | |||
| from qdrant_client.http import models | |||
| filter = models.Filter( | |||
| must=[ | |||
| models.FieldCondition( | |||
| key="group_id", | |||
| match=models.MatchValue(value=self._group_id), | |||
| ), | |||
| ], | |||
| ) | |||
| self._client.delete( | |||
| collection_name=self._collection_name, | |||
| points_selector=FilterSelector( | |||
| filter=filter | |||
| ), | |||
| ) | |||
| def delete_by_ids(self, ids: list[str]) -> None: | |||
| from qdrant_client.http import models | |||
| for node_id in ids: | |||
| filter = models.Filter( | |||
| must=[ | |||
| models.FieldCondition( | |||
| key="metadata.doc_id", | |||
| match=models.MatchValue(value=node_id), | |||
| ), | |||
| ], | |||
| ) | |||
| self._client.delete( | |||
| collection_name=self._collection_name, | |||
| points_selector=FilterSelector( | |||
| filter=filter | |||
| ), | |||
| ) | |||
| def text_exists(self, id: str) -> bool: | |||
| response = self._client.retrieve( | |||
| collection_name=self._collection_name, | |||
| ids=[id] | |||
| ) | |||
| return len(response) > 0 | |||
| def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: | |||
| from qdrant_client.http import models | |||
| filter = models.Filter( | |||
| must=[ | |||
| models.FieldCondition( | |||
| key="group_id", | |||
| match=models.MatchValue(value=self._group_id), | |||
| ), | |||
| ], | |||
| ) | |||
| results = self._client.search( | |||
| collection_name=self._collection_name, | |||
| query_vector=query_vector, | |||
| query_filter=filter, | |||
| limit=kwargs.get("top_k", 4), | |||
| with_payload=True, | |||
| with_vectors=True, | |||
| score_threshold=kwargs.get("score_threshold", .0) | |||
| ) | |||
| docs = [] | |||
| for result in results: | |||
| metadata = result.payload.get(Field.METADATA_KEY.value) or {} | |||
| # check score threshold | |||
| score_threshold = kwargs.get("score_threshold") or 0.0 | |||
| if result.score > score_threshold: | |||
| metadata['score'] = result.score | |||
| doc = Document( | |||
| page_content=result.payload.get(Field.CONTENT_KEY.value), | |||
| metadata=metadata, | |||
| ) | |||
| docs.append(doc) | |||
| return docs | |||
| def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: | |||
| """Return docs most similar by bm25. | |||
| Returns: | |||
| List of documents most similar to the query text and distance for each. | |||
| """ | |||
| from qdrant_client.http import models | |||
| scroll_filter = models.Filter( | |||
| must=[ | |||
| models.FieldCondition( | |||
| key="group_id", | |||
| match=models.MatchValue(value=self._group_id), | |||
| ), | |||
| models.FieldCondition( | |||
| key="page_content", | |||
| match=models.MatchText(text=query), | |||
| ) | |||
| ] | |||
| ) | |||
| response = self._client.scroll( | |||
| collection_name=self._collection_name, | |||
| scroll_filter=scroll_filter, | |||
| limit=kwargs.get('top_k', 2), | |||
| with_payload=True, | |||
| with_vectors=True | |||
| ) | |||
| results = response[0] | |||
| documents = [] | |||
| for result in results: | |||
| if result: | |||
| documents.append(self._document_from_scored_point( | |||
| result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value | |||
| )) | |||
| return documents | |||
| def _reload_if_needed(self): | |||
| if isinstance(self._client, QdrantLocal): | |||
| self._client = cast(QdrantLocal, self._client) | |||
| self._client._load() | |||
| @classmethod | |||
| def _document_from_scored_point( | |||
| cls, | |||
| scored_point: Any, | |||
| content_payload_key: str, | |||
| metadata_payload_key: str, | |||
| ) -> Document: | |||
| return Document( | |||
| page_content=scored_point.payload.get(content_payload_key), | |||
| metadata=scored_point.payload.get(metadata_payload_key) or {}, | |||
| ) | |||
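| # Hedged usage sketch for QdrantVector (assumes a reachable Qdrant instance and | |||
| # precomputed embeddings; collection/group names below are illustrative): | |||
| # | |||
| #     store = QdrantVector( | |||
| #         collection_name="Vector_index_example_Node", | |||
| #         group_id="dataset-uuid", | |||
| #         config=QdrantConfig(endpoint="http://localhost:6333", api_key=None, root_path="/app"), | |||
| #     ) | |||
| #     store.create(texts=documents, embeddings=doc_embeddings) | |||
| #     hits = store.search_by_vector(query_embedding, top_k=4, score_threshold=0.5) | |||
| #     keyword_hits = store.search_by_full_text("error handling", top_k=2) | |||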
| @@ -0,0 +1,62 @@ | |||
| from __future__ import annotations | |||
| from abc import ABC, abstractmethod | |||
| from typing import Any | |||
| from core.rag.models.document import Document | |||
| class BaseVector(ABC): | |||
| def __init__(self, collection_name: str): | |||
| self._collection_name = collection_name | |||
| @abstractmethod | |||
| def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): | |||
| raise NotImplementedError | |||
| @abstractmethod | |||
| def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): | |||
| raise NotImplementedError | |||
| @abstractmethod | |||
| def text_exists(self, id: str) -> bool: | |||
| raise NotImplementedError | |||
| @abstractmethod | |||
| def delete_by_ids(self, ids: list[str]) -> None: | |||
| raise NotImplementedError | |||
| @abstractmethod | |||
| def delete_by_metadata_field(self, key: str, value: str) -> None: | |||
| raise NotImplementedError | |||
| @abstractmethod | |||
| def search_by_vector( | |||
| self, | |||
| query_vector: list[float], | |||
| **kwargs: Any | |||
| ) -> list[Document]: | |||
| raise NotImplementedError | |||
| @abstractmethod | |||
| def search_by_full_text( | |||
| self, query: str, | |||
| **kwargs: Any | |||
| ) -> list[Document]: | |||
| raise NotImplementedError | |||
| def delete(self) -> None: | |||
| raise NotImplementedError | |||
| def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]: | |||
| for text in texts: | |||
| doc_id = text.metadata['doc_id'] | |||
| exists_duplicate_node = self.text_exists(doc_id) | |||
| if exists_duplicate_node: | |||
| texts.remove(text) | |||
| return texts | |||
| def _get_uuids(self, texts: list[Document]) -> list[str]: | |||
| return [text.metadata['doc_id'] for text in texts] | |||
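| # Sketch of what a new backend must supply on top of this interface (illustrative only; | |||
| # an in-memory store like this is not part of this change): | |||
| # | |||
| #     class InMemoryVector(BaseVector): | |||
| #         def __init__(self, collection_name: str): | |||
| #             super().__init__(collection_name) | |||
| #             self._points: dict[str, tuple[Document, list[float]]] = {} | |||
| #         def create(self, texts, embeddings, **kwargs): | |||
| #             self.add_texts(texts, embeddings, **kwargs) | |||
| #         def add_texts(self, documents, embeddings, **kwargs): | |||
| #             for doc, emb in zip(documents, embeddings): | |||
| #                 self._points[doc.metadata['doc_id']] = (doc, emb) | |||
| #         def text_exists(self, id: str) -> bool: | |||
| #             return id in self._points | |||
| #         # delete_by_ids, delete_by_metadata_field, search_by_vector and | |||
| #         # search_by_full_text would follow the same dictionary-based pattern. | |||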
| @@ -0,0 +1,171 @@ | |||
| from typing import Any, cast | |||
| from flask import current_app | |||
| from core.embedding.cached_embedding import CacheEmbedding | |||
| from core.model_manager import ModelManager | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from core.rag.datasource.entity.embedding import Embeddings | |||
| from core.rag.datasource.vdb.vector_base import BaseVector | |||
| from core.rag.models.document import Document | |||
| from extensions.ext_database import db | |||
| from models.dataset import Dataset, DatasetCollectionBinding | |||
| class Vector: | |||
| def __init__(self, dataset: Dataset, attributes: list = None): | |||
| if attributes is None: | |||
| attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash'] | |||
| self._dataset = dataset | |||
| self._embeddings = self._get_embeddings() | |||
| self._attributes = attributes | |||
| self._vector_processor = self._init_vector() | |||
| def _init_vector(self) -> BaseVector: | |||
| config = cast(dict, current_app.config) | |||
| vector_type = config.get('VECTOR_STORE') | |||
| if self._dataset.index_struct_dict: | |||
| vector_type = self._dataset.index_struct_dict['type'] | |||
| if not vector_type: | |||
| raise ValueError("Vector store must be specified.") | |||
| if vector_type == "weaviate": | |||
| from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateConfig, WeaviateVector | |||
| if self._dataset.index_struct_dict: | |||
| class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix'] | |||
| collection_name = class_prefix | |||
| else: | |||
| dataset_id = self._dataset.id | |||
| collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node' | |||
| return WeaviateVector( | |||
| collection_name=collection_name, | |||
| config=WeaviateConfig( | |||
| endpoint=config.get('WEAVIATE_ENDPOINT'), | |||
| api_key=config.get('WEAVIATE_API_KEY'), | |||
| batch_size=int(config.get('WEAVIATE_BATCH_SIZE')) | |||
| ), | |||
| attributes=self._attributes | |||
| ) | |||
| elif vector_type == "qdrant": | |||
| from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig, QdrantVector | |||
| if self._dataset.collection_binding_id: | |||
| dataset_collection_binding = db.session.query(DatasetCollectionBinding). \ | |||
| filter(DatasetCollectionBinding.id == self._dataset.collection_binding_id). \ | |||
| one_or_none() | |||
| if dataset_collection_binding: | |||
| collection_name = dataset_collection_binding.collection_name | |||
| else: | |||
| raise ValueError('Dataset collection binding does not exist.') | |||
| else: | |||
| if self._dataset.index_struct_dict: | |||
| class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix'] | |||
| collection_name = class_prefix | |||
| else: | |||
| dataset_id = self._dataset.id | |||
| collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node' | |||
| return QdrantVector( | |||
| collection_name=collection_name, | |||
| group_id=self._dataset.id, | |||
| config=QdrantConfig( | |||
| endpoint=config.get('QDRANT_URL'), | |||
| api_key=config.get('QDRANT_API_KEY'), | |||
| root_path=current_app.root_path, | |||
| timeout=config.get('QDRANT_CLIENT_TIMEOUT') | |||
| ) | |||
| ) | |||
| elif vector_type == "milvus": | |||
| from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig, MilvusVector | |||
| if self._dataset.index_struct_dict: | |||
| class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix'] | |||
| collection_name = class_prefix | |||
| else: | |||
| dataset_id = self._dataset.id | |||
| collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node' | |||
| return MilvusVector( | |||
| collection_name=collection_name, | |||
| config=MilvusConfig( | |||
| host=config.get('MILVUS_HOST'), | |||
| port=config.get('MILVUS_PORT'), | |||
| user=config.get('MILVUS_USER'), | |||
| password=config.get('MILVUS_PASSWORD'), | |||
| secure=config.get('MILVUS_SECURE'), | |||
| ) | |||
| ) | |||
| else: | |||
| raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.") | |||
| def create(self, texts: list = None, **kwargs): | |||
| if texts: | |||
| embeddings = self._embeddings.embed_documents([document.page_content for document in texts]) | |||
| self._vector_processor.create( | |||
| texts=texts, | |||
| embeddings=embeddings, | |||
| **kwargs | |||
| ) | |||
| def add_texts(self, documents: list[Document], **kwargs): | |||
| if kwargs.get('duplicate_check', False): | |||
| documents = self._filter_duplicate_texts(documents) | |||
| embeddings = self._embeddings.embed_documents([document.page_content for document in documents]) | |||
| self._vector_processor.add_texts( | |||
| documents=documents, | |||
| embeddings=embeddings, | |||
| **kwargs | |||
| ) | |||
| def text_exists(self, id: str) -> bool: | |||
| return self._vector_processor.text_exists(id) | |||
| def delete_by_ids(self, ids: list[str]) -> None: | |||
| self._vector_processor.delete_by_ids(ids) | |||
| def delete_by_metadata_field(self, key: str, value: str) -> None: | |||
| self._vector_processor.delete_by_metadata_field(key, value) | |||
| def search_by_vector( | |||
| self, query: str, | |||
| **kwargs: Any | |||
| ) -> list[Document]: | |||
| query_vector = self._embeddings.embed_query(query) | |||
| return self._vector_processor.search_by_vector(query_vector, **kwargs) | |||
| def search_by_full_text( | |||
| self, query: str, | |||
| **kwargs: Any | |||
| ) -> list[Document]: | |||
| return self._vector_processor.search_by_full_text(query, **kwargs) | |||
| def delete(self) -> None: | |||
| self._vector_processor.delete() | |||
| def _get_embeddings(self) -> Embeddings: | |||
| model_manager = ModelManager() | |||
| embedding_model = model_manager.get_model_instance( | |||
| tenant_id=self._dataset.tenant_id, | |||
| provider=self._dataset.embedding_model_provider, | |||
| model_type=ModelType.TEXT_EMBEDDING, | |||
| model=self._dataset.embedding_model | |||
| ) | |||
| return CacheEmbedding(embedding_model) | |||
| def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]: | |||
| for text in texts: | |||
| doc_id = text.metadata['doc_id'] | |||
| exists_duplicate_node = self.text_exists(doc_id) | |||
| if exists_duplicate_node: | |||
| texts.remove(text) | |||
| return texts | |||
| def __getattr__(self, name): | |||
| if self._vector_processor is not None: | |||
| method = getattr(self._vector_processor, name) | |||
| if callable(method): | |||
| return method | |||
| raise AttributeError(f"'vector_processor' object has no attribute '{name}'") | |||
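| # Hedged usage sketch for the factory above (assumes an existing Dataset row and a | |||
| # configured VECTOR_STORE; the query strings are illustrative): | |||
| # | |||
| #     vector = Vector(dataset)                      # picks weaviate/qdrant/milvus from config | |||
| #     vector.add_texts(documents, duplicate_check=True) | |||
| #     semantic_docs = vector.search_by_vector("How do I reset my password?", top_k=4) | |||
| #     keyword_docs = vector.search_by_full_text("reset password", top_k=2) | |||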
| @@ -0,0 +1,235 @@ | |||
| import datetime | |||
| from typing import Any, Optional | |||
| import requests | |||
| import weaviate | |||
| from pydantic import BaseModel, root_validator | |||
| from core.rag.datasource.vdb.field import Field | |||
| from core.rag.datasource.vdb.vector_base import BaseVector | |||
| from core.rag.models.document import Document | |||
| from models.dataset import Dataset | |||
| class WeaviateConfig(BaseModel): | |||
| endpoint: str | |||
| api_key: Optional[str] | |||
| batch_size: int = 100 | |||
| @root_validator() | |||
| def validate_config(cls, values: dict) -> dict: | |||
| if not values['endpoint']: | |||
| raise ValueError("config WEAVIATE_ENDPOINT is required") | |||
| return values | |||
| class WeaviateVector(BaseVector): | |||
| def __init__(self, collection_name: str, config: WeaviateConfig, attributes: list): | |||
| super().__init__(collection_name) | |||
| self._client = self._init_client(config) | |||
| self._attributes = attributes | |||
| def _init_client(self, config: WeaviateConfig) -> weaviate.Client: | |||
| auth_config = weaviate.auth.AuthApiKey(api_key=config.api_key) | |||
| weaviate.connect.connection.has_grpc = False | |||
| try: | |||
| client = weaviate.Client( | |||
| url=config.endpoint, | |||
| auth_client_secret=auth_config, | |||
| timeout_config=(5, 60), | |||
| startup_period=None | |||
| ) | |||
| except requests.exceptions.ConnectionError: | |||
| raise ConnectionError("Vector database connection error") | |||
| client.batch.configure( | |||
| # `batch_size` takes an `int` value to enable auto-batching | |||
| # (`None` is used for manual batching) | |||
| batch_size=config.batch_size, | |||
| # dynamically update the `batch_size` based on import speed | |||
| dynamic=True, | |||
| # `timeout_retries` takes an `int` value to retry on time outs | |||
| timeout_retries=3, | |||
| ) | |||
| return client | |||
| def get_type(self) -> str: | |||
| return 'weaviate' | |||
| def get_collection_name(self, dataset: Dataset) -> str: | |||
| if dataset.index_struct_dict: | |||
| class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] | |||
| if not class_prefix.endswith('_Node'): | |||
| # original class_prefix | |||
| class_prefix += '_Node' | |||
| return class_prefix | |||
| dataset_id = dataset.id | |||
| return "Vector_index_" + dataset_id.replace("-", "_") + '_Node' | |||
| def to_index_struct(self) -> dict: | |||
| return { | |||
| "type": self.get_type(), | |||
| "vector_store": {"class_prefix": self._collection_name} | |||
| } | |||
| def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): | |||
| schema = self._default_schema(self._collection_name) | |||
| # check whether the index already exists | |||
| if not self._client.schema.contains(schema): | |||
| # create collection | |||
| self._client.schema.create_class(schema) | |||
| # create vector | |||
| self.add_texts(texts, embeddings) | |||
| def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): | |||
| uuids = self._get_uuids(documents) | |||
| texts = [d.page_content for d in documents] | |||
| metadatas = [d.metadata for d in documents] | |||
| ids = [] | |||
| with self._client.batch as batch: | |||
| for i, text in enumerate(texts): | |||
| data_properties = {Field.TEXT_KEY.value: text} | |||
| if metadatas is not None: | |||
| for key, val in metadatas[i].items(): | |||
| data_properties[key] = self._json_serializable(val) | |||
| batch.add_data_object( | |||
| data_object=data_properties, | |||
| class_name=self._collection_name, | |||
| uuid=uuids[i], | |||
| vector=embeddings[i] if embeddings else None, | |||
| ) | |||
| ids.append(uuids[i]) | |||
| return ids | |||
| def delete_by_metadata_field(self, key: str, value: str): | |||
| where_filter = { | |||
| "operator": "Equal", | |||
| "path": [key], | |||
| "valueText": value | |||
| } | |||
| self._client.batch.delete_objects( | |||
| class_name=self._collection_name, | |||
| where=where_filter, | |||
| output='minimal' | |||
| ) | |||
| def delete(self): | |||
| self._client.schema.delete_class(self._collection_name) | |||
| def text_exists(self, id: str) -> bool: | |||
| collection_name = self._collection_name | |||
| result = self._client.query.get(collection_name).with_additional(["id"]).with_where({ | |||
| "path": ["doc_id"], | |||
| "operator": "Equal", | |||
| "valueText": id, | |||
| }).with_limit(1).do() | |||
| if "errors" in result: | |||
| raise ValueError(f"Error during query: {result['errors']}") | |||
| entries = result["data"]["Get"][collection_name] | |||
| if len(entries) == 0: | |||
| return False | |||
| return True | |||
| def delete_by_ids(self, ids: list[str]) -> None: | |||
| self._client.data_object.delete( | |||
| ids, | |||
| class_name=self._collection_name | |||
| ) | |||
| def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: | |||
| """Look up similar documents by embedding vector in Weaviate.""" | |||
| collection_name = self._collection_name | |||
| properties = self._attributes | |||
| properties.append(Field.TEXT_KEY.value) | |||
| query_obj = self._client.query.get(collection_name, properties) | |||
| vector = {"vector": query_vector} | |||
| if kwargs.get("where_filter"): | |||
| query_obj = query_obj.with_where(kwargs.get("where_filter")) | |||
| result = ( | |||
| query_obj.with_near_vector(vector) | |||
| .with_limit(kwargs.get("top_k", 4)) | |||
| .with_additional(["vector", "distance"]) | |||
| .do() | |||
| ) | |||
| if "errors" in result: | |||
| raise ValueError(f"Error during query: {result['errors']}") | |||
| docs_and_scores = [] | |||
| for res in result["data"]["Get"][collection_name]: | |||
| text = res.pop(Field.TEXT_KEY.value) | |||
| score = 1 - res["_additional"]["distance"] | |||
| docs_and_scores.append((Document(page_content=text, metadata=res), score)) | |||
| docs = [] | |||
| for doc, score in docs_and_scores: | |||
| score_threshold = kwargs.get("score_threshold") or 0.0 | |||
| # check score threshold | |||
| if score > score_threshold: | |||
| doc.metadata['score'] = score | |||
| docs.append(doc) | |||
| return docs | |||
| def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: | |||
| """Return docs using BM25F. | |||
| Args: | |||
| query: Text to look up documents similar to. | |||
| top_k: Number of Documents to return. Defaults to 2. | |||
| Returns: | |||
| List of Documents most similar to the query. | |||
| """ | |||
| collection_name = self._collection_name | |||
| content: dict[str, Any] = {"concepts": [query]} | |||
| properties = self._attributes | |||
| properties.append(Field.TEXT_KEY.value) | |||
| if kwargs.get("search_distance"): | |||
| content["certainty"] = kwargs.get("search_distance") | |||
| query_obj = self._client.query.get(collection_name, properties) | |||
| if kwargs.get("where_filter"): | |||
| query_obj = query_obj.with_where(kwargs.get("where_filter")) | |||
| if kwargs.get("additional"): | |||
| query_obj = query_obj.with_additional(kwargs.get("additional")) | |||
| properties = ['text'] | |||
| result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get('top_k', 2)).do() | |||
| if "errors" in result: | |||
| raise ValueError(f"Error during query: {result['errors']}") | |||
| docs = [] | |||
| for res in result["data"]["Get"][collection_name]: | |||
| text = res.pop(Field.TEXT_KEY.value) | |||
| docs.append(Document(page_content=text, metadata=res)) | |||
| return docs | |||
| def _default_schema(self, index_name: str) -> dict: | |||
| return { | |||
| "class": index_name, | |||
| "properties": [ | |||
| { | |||
| "name": "text", | |||
| "dataType": ["text"], | |||
| } | |||
| ], | |||
| } | |||
| def _json_serializable(self, value: Any) -> Any: | |||
| if isinstance(value, datetime.datetime): | |||
| return value.isoformat() | |||
| return value | |||
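| # The optional `where_filter` kwarg accepted by search_by_vector / search_by_full_text | |||
| # uses Weaviate's where-filter shape, e.g. (field name and value are illustrative): | |||
| # | |||
| #     where_filter = { | |||
| #         "path": ["document_id"], | |||
| #         "operator": "Equal", | |||
| #         "valueText": "some-document-uuid", | |||
| #     } | |||
| #     docs = store.search_by_vector(query_embedding, top_k=4, where_filter=where_filter) | |||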
| @@ -0,0 +1,166 @@ | |||
| """Schema for Blobs and Blob Loaders. | |||
| The goal is to facilitate decoupling of content loading from content parsing code. | |||
| In addition, content loading code should provide a lazy loading interface by default. | |||
| """ | |||
| from __future__ import annotations | |||
| import contextlib | |||
| import mimetypes | |||
| from abc import ABC, abstractmethod | |||
| from collections.abc import Generator, Iterable, Mapping | |||
| from io import BufferedReader, BytesIO | |||
| from pathlib import PurePath | |||
| from typing import Any, Optional, Union | |||
| from pydantic import BaseModel, root_validator | |||
| PathLike = Union[str, PurePath] | |||
| class Blob(BaseModel): | |||
| """A blob is used to represent raw data by either reference or value. | |||
| Provides an interface to materialize the blob in different representations, and | |||
| helps to decouple the development of data loaders from the downstream parsing of | |||
| the raw data. | |||
| Inspired by: https://developer.mozilla.org/en-US/docs/Web/API/Blob | |||
| """ | |||
| data: Union[bytes, str, None] # Raw data | |||
| mimetype: Optional[str] = None # Not to be confused with a file extension | |||
| encoding: str = "utf-8" # Use utf-8 as default encoding, if decoding to string | |||
| # Location where the original content was found | |||
| # Represent location on the local file system | |||
| # Useful for situations where downstream code assumes it must work with file paths | |||
| # rather than in-memory content. | |||
| path: Optional[PathLike] = None | |||
| class Config: | |||
| arbitrary_types_allowed = True | |||
| frozen = True | |||
| @property | |||
| def source(self) -> Optional[str]: | |||
| """The source location of the blob as string if known otherwise none.""" | |||
| return str(self.path) if self.path else None | |||
| @root_validator(pre=True) | |||
| def check_blob_is_valid(cls, values: Mapping[str, Any]) -> Mapping[str, Any]: | |||
| """Verify that either data or path is provided.""" | |||
| if "data" not in values and "path" not in values: | |||
| raise ValueError("Either data or path must be provided") | |||
| return values | |||
| def as_string(self) -> str: | |||
| """Read data as a string.""" | |||
| if self.data is None and self.path: | |||
| with open(str(self.path), encoding=self.encoding) as f: | |||
| return f.read() | |||
| elif isinstance(self.data, bytes): | |||
| return self.data.decode(self.encoding) | |||
| elif isinstance(self.data, str): | |||
| return self.data | |||
| else: | |||
| raise ValueError(f"Unable to get string for blob {self}") | |||
| def as_bytes(self) -> bytes: | |||
| """Read data as bytes.""" | |||
| if isinstance(self.data, bytes): | |||
| return self.data | |||
| elif isinstance(self.data, str): | |||
| return self.data.encode(self.encoding) | |||
| elif self.data is None and self.path: | |||
| with open(str(self.path), "rb") as f: | |||
| return f.read() | |||
| else: | |||
| raise ValueError(f"Unable to get bytes for blob {self}") | |||
| @contextlib.contextmanager | |||
| def as_bytes_io(self) -> Generator[Union[BytesIO, BufferedReader], None, None]: | |||
| """Read data as a byte stream.""" | |||
| if isinstance(self.data, bytes): | |||
| yield BytesIO(self.data) | |||
| elif self.data is None and self.path: | |||
| with open(str(self.path), "rb") as f: | |||
| yield f | |||
| else: | |||
| raise NotImplementedError(f"Unable to convert blob {self}") | |||
| @classmethod | |||
| def from_path( | |||
| cls, | |||
| path: PathLike, | |||
| *, | |||
| encoding: str = "utf-8", | |||
| mime_type: Optional[str] = None, | |||
| guess_type: bool = True, | |||
| ) -> Blob: | |||
| """Load the blob from a path like object. | |||
| Args: | |||
| path: path like object to file to be read | |||
| encoding: Encoding to use if decoding the bytes into a string | |||
| mime_type: if provided, will be set as the mime-type of the data | |||
| guess_type: If True, the mimetype will be guessed from the file extension, | |||
| if a mime-type was not provided | |||
| Returns: | |||
| Blob instance | |||
| """ | |||
| if mime_type is None and guess_type: | |||
| _mimetype = mimetypes.guess_type(path)[0] if guess_type else None | |||
| else: | |||
| _mimetype = mime_type | |||
| # We do not load the data immediately, instead we treat the blob as a | |||
| # reference to the underlying data. | |||
| return cls(data=None, mimetype=_mimetype, encoding=encoding, path=path) | |||
| @classmethod | |||
| def from_data( | |||
| cls, | |||
| data: Union[str, bytes], | |||
| *, | |||
| encoding: str = "utf-8", | |||
| mime_type: Optional[str] = None, | |||
| path: Optional[str] = None, | |||
| ) -> Blob: | |||
| """Initialize the blob from in-memory data. | |||
| Args: | |||
| data: the in-memory data associated with the blob | |||
| encoding: Encoding to use if decoding the bytes into a string | |||
| mime_type: if provided, will be set as the mime-type of the data | |||
| path: if provided, will be set as the source from which the data came | |||
| Returns: | |||
| Blob instance | |||
| """ | |||
| return cls(data=data, mimetype=mime_type, encoding=encoding, path=path) | |||
| def __repr__(self) -> str: | |||
| """Define the blob representation.""" | |||
| str_repr = f"Blob {id(self)}" | |||
| if self.source: | |||
| str_repr += f" {self.source}" | |||
| return str_repr | |||
| class BlobLoader(ABC): | |||
| """Abstract interface for blob loaders implementation. | |||
| Implementer should be able to load raw content from a datasource system according | |||
| to some criteria and return the raw content lazily as a stream of blobs. | |||
| """ | |||
| @abstractmethod | |||
| def yield_blobs( | |||
| self, | |||
| ) -> Iterable[Blob]: | |||
| """A lazy loader for raw data represented by LangChain's Blob object. | |||
| Returns: | |||
| A generator over blobs | |||
| """ | |||
| @@ -0,0 +1,71 @@ | |||
| """Abstract interface for document loader implementations.""" | |||
| import csv | |||
| from typing import Optional | |||
| from core.rag.extractor.extractor_base import BaseExtractor | |||
| from core.rag.extractor.helpers import detect_file_encodings | |||
| from core.rag.models.document import Document | |||
| class CSVExtractor(BaseExtractor): | |||
| """Load CSV files. | |||
| Args: | |||
| file_path: Path to the file to load. | |||
| """ | |||
| def __init__( | |||
| self, | |||
| file_path: str, | |||
| encoding: Optional[str] = None, | |||
| autodetect_encoding: bool = False, | |||
| source_column: Optional[str] = None, | |||
| csv_args: Optional[dict] = None, | |||
| ): | |||
| """Initialize with file path.""" | |||
| self._file_path = file_path | |||
| self._encoding = encoding | |||
| self._autodetect_encoding = autodetect_encoding | |||
| self.source_column = source_column | |||
| self.csv_args = csv_args or {} | |||
| def extract(self) -> list[Document]: | |||
| """Load data into document objects.""" | |||
| try: | |||
| with open(self._file_path, newline="", encoding=self._encoding) as csvfile: | |||
| docs = self._read_from_file(csvfile) | |||
| except UnicodeDecodeError as e: | |||
| if self._autodetect_encoding: | |||
| detected_encodings = detect_file_encodings(self._file_path) | |||
| for encoding in detected_encodings: | |||
| try: | |||
| with open(self._file_path, newline="", encoding=encoding.encoding) as csvfile: | |||
| docs = self._read_from_file(csvfile) | |||
| break | |||
| except UnicodeDecodeError: | |||
| continue | |||
| else: | |||
| raise RuntimeError(f"Error loading {self._file_path}") from e | |||
| return docs | |||
| def _read_from_file(self, csvfile) -> list[Document]: | |||
| docs = [] | |||
| csv_reader = csv.DictReader(csvfile, **self.csv_args) # type: ignore | |||
| for i, row in enumerate(csv_reader): | |||
| content = "\n".join(f"{k.strip()}: {v.strip()}" for k, v in row.items()) | |||
| try: | |||
| source = ( | |||
| row[self.source_column] | |||
| if self.source_column is not None | |||
| else '' | |||
| ) | |||
| except KeyError: | |||
| raise ValueError( | |||
| f"Source column '{self.source_column}' not found in CSV file." | |||
| ) | |||
| metadata = {"source": source, "row": i} | |||
| doc = Document(page_content=content, metadata=metadata) | |||
| docs.append(doc) | |||
| return docs | |||
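| # Usage sketch (path, source_column and csv_args are illustrative): | |||
| # | |||
| #     extractor = CSVExtractor("/tmp/products.csv", autodetect_encoding=True, | |||
| #                              source_column="url", csv_args={"delimiter": ";"}) | |||
| #     docs = extractor.extract()   # one Document per row, "header: value" lines joined by \n | |||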
| @@ -0,0 +1,6 @@ | |||
| from enum import Enum | |||
| class DatasourceType(Enum): | |||
| FILE = "upload_file" | |||
| NOTION = "notion_import" | |||
| @@ -0,0 +1,36 @@ | |||
| from pydantic import BaseModel | |||
| from models.dataset import Document | |||
| from models.model import UploadFile | |||
| class NotionInfo(BaseModel): | |||
| """ | |||
| Notion import info. | |||
| """ | |||
| notion_workspace_id: str | |||
| notion_obj_id: str | |||
| notion_page_type: str | |||
| document: Document = None | |||
| class Config: | |||
| arbitrary_types_allowed = True | |||
| def __init__(self, **data) -> None: | |||
| super().__init__(**data) | |||
| class ExtractSetting(BaseModel): | |||
| """ | |||
| Model class for document extraction settings. | |||
| """ | |||
| datasource_type: str | |||
| upload_file: UploadFile = None | |||
| notion_info: NotionInfo = None | |||
| document_model: str = None | |||
| class Config: | |||
| arbitrary_types_allowed = True | |||
| def __init__(self, **data) -> None: | |||
| super().__init__(**data) | |||
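| # Illustrative construction of the two supported settings (IDs are placeholders): | |||
| # | |||
| #     file_setting = ExtractSetting(datasource_type="upload_file", | |||
| #                                   upload_file=upload_file_row, | |||
| #                                   document_model="text_model") | |||
| #     notion_setting = ExtractSetting(datasource_type="notion_import", | |||
| #                                     notion_info=NotionInfo(notion_workspace_id="ws-id", | |||
| #                                                            notion_obj_id="page-id", | |||
| #                                                            notion_page_type="page"), | |||
| #                                     document_model="text_model") | |||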
| @@ -1,14 +1,14 @@ | |||
| import logging | |||
| """Abstract interface for document loader implementations.""" | |||
| from typing import Optional | |||
| from langchain.document_loaders.base import BaseLoader | |||
| from langchain.schema import Document | |||
| from openpyxl.reader.excel import load_workbook | |||
| logger = logging.getLogger(__name__) | |||
| from core.rag.extractor.extractor_base import BaseExtractor | |||
| from core.rag.models.document import Document | |||
| class ExcelLoader(BaseLoader): | |||
| """Load xlxs files. | |||
| class ExcelExtractor(BaseExtractor): | |||
| """Load Excel files. | |||
| Args: | |||
| @@ -16,13 +16,18 @@ class ExcelLoader(BaseLoader): | |||
| """ | |||
| def __init__( | |||
| self, | |||
| file_path: str | |||
| self, | |||
| file_path: str, | |||
| encoding: Optional[str] = None, | |||
| autodetect_encoding: bool = False | |||
| ): | |||
| """Initialize with file path.""" | |||
| self._file_path = file_path | |||
| self._encoding = encoding | |||
| self._autodetect_encoding = autodetect_encoding | |||
| def load(self) -> list[Document]: | |||
| def extract(self) -> list[Document]: | |||
| """Load from file path.""" | |||
| data = [] | |||
| keys = [] | |||
| wb = load_workbook(filename=self._file_path, read_only=True) | |||
| @@ -0,0 +1,139 @@ | |||
| import tempfile | |||
| from pathlib import Path | |||
| from typing import Union | |||
| import requests | |||
| from flask import current_app | |||
| from core.rag.extractor.csv_extractor import CSVExtractor | |||
| from core.rag.extractor.entity.datasource_type import DatasourceType | |||
| from core.rag.extractor.entity.extract_setting import ExtractSetting | |||
| from core.rag.extractor.excel_extractor import ExcelExtractor | |||
| from core.rag.extractor.html_extractor import HtmlExtractor | |||
| from core.rag.extractor.markdown_extractor import MarkdownExtractor | |||
| from core.rag.extractor.notion_extractor import NotionExtractor | |||
| from core.rag.extractor.pdf_extractor import PdfExtractor | |||
| from core.rag.extractor.text_extractor import TextExtractor | |||
| from core.rag.extractor.unstructured.unstructured_doc_extractor import UnstructuredWordExtractor | |||
| from core.rag.extractor.unstructured.unstructured_eml_extractor import UnstructuredEmailExtractor | |||
| from core.rag.extractor.unstructured.unstructured_markdown_extractor import UnstructuredMarkdownExtractor | |||
| from core.rag.extractor.unstructured.unstructured_msg_extractor import UnstructuredMsgExtractor | |||
| from core.rag.extractor.unstructured.unstructured_ppt_extractor import UnstructuredPPTExtractor | |||
| from core.rag.extractor.unstructured.unstructured_pptx_extractor import UnstructuredPPTXExtractor | |||
| from core.rag.extractor.unstructured.unstructured_text_extractor import UnstructuredTextExtractor | |||
| from core.rag.extractor.unstructured.unstructured_xml_extractor import UnstructuredXmlExtractor | |||
| from core.rag.extractor.word_extractor import WordExtractor | |||
| from core.rag.models.document import Document | |||
| from extensions.ext_storage import storage | |||
| from models.model import UploadFile | |||
| SUPPORT_URL_CONTENT_TYPES = ['application/pdf', 'text/plain'] | |||
| USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" | |||
| class ExtractProcessor: | |||
| @classmethod | |||
| def load_from_upload_file(cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False) \ | |||
| -> Union[list[Document], str]: | |||
| extract_setting = ExtractSetting( | |||
| datasource_type="upload_file", | |||
| upload_file=upload_file, | |||
| document_model='text_model' | |||
| ) | |||
| if return_text: | |||
| delimiter = '\n' | |||
| return delimiter.join([document.page_content for document in cls.extract(extract_setting, is_automatic)]) | |||
| else: | |||
| return cls.extract(extract_setting, is_automatic) | |||
| @classmethod | |||
| def load_from_url(cls, url: str, return_text: bool = False) -> Union[list[Document], str]: | |||
| response = requests.get(url, headers={ | |||
| "User-Agent": USER_AGENT | |||
| }) | |||
| with tempfile.TemporaryDirectory() as temp_dir: | |||
| suffix = Path(url).suffix | |||
| file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" | |||
| with open(file_path, 'wb') as file: | |||
| file.write(response.content) | |||
| extract_setting = ExtractSetting( | |||
| datasource_type="upload_file", | |||
| document_model='text_model' | |||
| ) | |||
| if return_text: | |||
| delimiter = '\n' | |||
| return delimiter.join([document.page_content for document in cls.extract( | |||
| extract_setting=extract_setting, file_path=file_path)]) | |||
| else: | |||
| return cls.extract(extract_setting=extract_setting, file_path=file_path) | |||
| @classmethod | |||
| def extract(cls, extract_setting: ExtractSetting, is_automatic: bool = False, | |||
| file_path: str = None) -> list[Document]: | |||
| if extract_setting.datasource_type == DatasourceType.FILE.value: | |||
| with tempfile.TemporaryDirectory() as temp_dir: | |||
| if not file_path: | |||
| upload_file: UploadFile = extract_setting.upload_file | |||
| suffix = Path(upload_file.key).suffix | |||
| file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" | |||
| storage.download(upload_file.key, file_path) | |||
| input_file = Path(file_path) | |||
| file_extension = input_file.suffix.lower() | |||
| etl_type = current_app.config['ETL_TYPE'] | |||
| unstructured_api_url = current_app.config['UNSTRUCTURED_API_URL'] | |||
| if etl_type == 'Unstructured': | |||
| if file_extension == '.xlsx': | |||
| extractor = ExcelExtractor(file_path) | |||
| elif file_extension == '.pdf': | |||
| extractor = PdfExtractor(file_path) | |||
| elif file_extension in ['.md', '.markdown']: | |||
| extractor = UnstructuredMarkdownExtractor(file_path, unstructured_api_url) if is_automatic \ | |||
| else MarkdownExtractor(file_path, autodetect_encoding=True) | |||
| elif file_extension in ['.htm', '.html']: | |||
| extractor = HtmlExtractor(file_path) | |||
| elif file_extension in ['.docx']: | |||
| extractor = UnstructuredWordExtractor(file_path, unstructured_api_url) | |||
| elif file_extension == '.csv': | |||
| extractor = CSVExtractor(file_path, autodetect_encoding=True) | |||
| elif file_extension == '.msg': | |||
| extractor = UnstructuredMsgExtractor(file_path, unstructured_api_url) | |||
| elif file_extension == '.eml': | |||
| extractor = UnstructuredEmailExtractor(file_path, unstructured_api_url) | |||
| elif file_extension == '.ppt': | |||
| extractor = UnstructuredPPTExtractor(file_path, unstructured_api_url) | |||
| elif file_extension == '.pptx': | |||
| extractor = UnstructuredPPTXExtractor(file_path, unstructured_api_url) | |||
| elif file_extension == '.xml': | |||
| extractor = UnstructuredXmlExtractor(file_path, unstructured_api_url) | |||
| else: | |||
| # txt | |||
| extractor = UnstructuredTextExtractor(file_path, unstructured_api_url) if is_automatic \ | |||
| else TextExtractor(file_path, autodetect_encoding=True) | |||
| else: | |||
| if file_extension == '.xlsx': | |||
| extractor = ExcelExtractor(file_path) | |||
| elif file_extension == '.pdf': | |||
| extractor = PdfExtractor(file_path) | |||
| elif file_extension in ['.md', '.markdown']: | |||
| extractor = MarkdownExtractor(file_path, autodetect_encoding=True) | |||
| elif file_extension in ['.htm', '.html']: | |||
| extractor = HtmlExtractor(file_path) | |||
| elif file_extension in ['.docx']: | |||
| extractor = WordExtractor(file_path) | |||
| elif file_extension == '.csv': | |||
| extractor = CSVExtractor(file_path, autodetect_encoding=True) | |||
| else: | |||
| # txt | |||
| extractor = TextExtractor(file_path, autodetect_encoding=True) | |||
| return extractor.extract() | |||
| elif extract_setting.datasource_type == DatasourceType.NOTION.value: | |||
| extractor = NotionExtractor( | |||
| notion_workspace_id=extract_setting.notion_info.notion_workspace_id, | |||
| notion_obj_id=extract_setting.notion_info.notion_obj_id, | |||
| notion_page_type=extract_setting.notion_info.notion_page_type, | |||
| document_model=extract_setting.notion_info.document | |||
| ) | |||
| return extractor.extract() | |||
| else: | |||
| raise ValueError(f"Unsupported datasource type: {extract_setting.datasource_type}") | |||
| @@ -0,0 +1,12 @@ | |||
| """Abstract interface for document loader implementations.""" | |||
| from abc import ABC, abstractmethod | |||
| class BaseExtractor(ABC): | |||
| """Interface for extract files. | |||
| """ | |||
| @abstractmethod | |||
| def extract(self): | |||
| raise NotImplementedError | |||
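| # A custom implementation only needs `extract()` (illustrative sketch; assumes the | |||
| # Document model from core.rag.models.document is imported): | |||
| # | |||
| #     class JsonLinesExtractor(BaseExtractor): | |||
| #         def __init__(self, file_path: str): | |||
| #             self._file_path = file_path | |||
| #         def extract(self) -> list[Document]: | |||
| #             import json | |||
| #             with open(self._file_path, encoding="utf-8") as f: | |||
| #                 return [Document(page_content=json.loads(line)["text"]) for line in f] | |||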
| @@ -0,0 +1,46 @@ | |||
| """Document loader helpers.""" | |||
| import concurrent.futures | |||
| from typing import NamedTuple, Optional, cast | |||
| class FileEncoding(NamedTuple): | |||
| """A file encoding as the NamedTuple.""" | |||
| encoding: Optional[str] | |||
| """The encoding of the file.""" | |||
| confidence: float | |||
| """The confidence of the encoding.""" | |||
| language: Optional[str] | |||
| """The language of the file.""" | |||
| def detect_file_encodings(file_path: str, timeout: int = 5) -> list[FileEncoding]: | |||
| """Try to detect the file encoding. | |||
| Returns a list of `FileEncoding` tuples with the detected encodings ordered | |||
| by confidence. | |||
| Args: | |||
| file_path: The path to the file to detect the encoding for. | |||
| timeout: The timeout in seconds for the encoding detection. | |||
| """ | |||
| import chardet | |||
| def read_and_detect(file_path: str) -> list[dict]: | |||
| with open(file_path, "rb") as f: | |||
| rawdata = f.read() | |||
| return cast(list[dict], chardet.detect_all(rawdata)) | |||
| with concurrent.futures.ThreadPoolExecutor() as executor: | |||
| future = executor.submit(read_and_detect, file_path) | |||
| try: | |||
| encodings = future.result(timeout=timeout) | |||
| except concurrent.futures.TimeoutError: | |||
| raise TimeoutError( | |||
| f"Timeout reached while detecting encoding for {file_path}" | |||
| ) | |||
| if all(encoding["encoding"] is None for encoding in encodings): | |||
| raise RuntimeError(f"Could not detect encoding for {file_path}") | |||
| return [FileEncoding(**enc) for enc in encodings if enc["encoding"] is not None] | |||
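| # Usage sketch (the file path is illustrative): | |||
| # | |||
| #     encodings = detect_file_encodings("/tmp/legacy_export.csv") | |||
| #     best_guess = encodings[0].encoding   # highest-confidence encoding reported by chardet | |||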
| @@ -1,51 +1,55 @@ | |||
| import csv | |||
| import logging | |||
| """Abstract interface for document loader implementations.""" | |||
| from typing import Optional | |||
| from langchain.document_loaders import CSVLoader as LCCSVLoader | |||
| from langchain.document_loaders.helpers import detect_file_encodings | |||
| from langchain.schema import Document | |||
| from core.rag.extractor.extractor_base import BaseExtractor | |||
| from core.rag.extractor.helpers import detect_file_encodings | |||
| from core.rag.models.document import Document | |||
| logger = logging.getLogger(__name__) | |||
| class HtmlExtractor(BaseExtractor): | |||
| """Load html files. | |||
| Args: | |||
| file_path: Path to the file to load. | |||
| """ | |||
| class CSVLoader(LCCSVLoader): | |||
| def __init__( | |||
| self, | |||
| file_path: str, | |||
| encoding: Optional[str] = None, | |||
| autodetect_encoding: bool = False, | |||
| source_column: Optional[str] = None, | |||
| csv_args: Optional[dict] = None, | |||
| encoding: Optional[str] = None, | |||
| autodetect_encoding: bool = True, | |||
| ): | |||
| self.file_path = file_path | |||
| """Initialize with file path.""" | |||
| self._file_path = file_path | |||
| self._encoding = encoding | |||
| self._autodetect_encoding = autodetect_encoding | |||
| self.source_column = source_column | |||
| self.encoding = encoding | |||
| self.csv_args = csv_args or {} | |||
| self.autodetect_encoding = autodetect_encoding | |||
| def load(self) -> list[Document]: | |||
| def extract(self) -> list[Document]: | |||
| """Load data into document objects.""" | |||
| try: | |||
| with open(self.file_path, newline="", encoding=self.encoding) as csvfile: | |||
| with open(self._file_path, newline="", encoding=self._encoding) as csvfile: | |||
| docs = self._read_from_file(csvfile) | |||
| except UnicodeDecodeError as e: | |||
| if self.autodetect_encoding: | |||
| detected_encodings = detect_file_encodings(self.file_path) | |||
| if self._autodetect_encoding: | |||
| detected_encodings = detect_file_encodings(self._file_path) | |||
| for encoding in detected_encodings: | |||
| logger.debug("Trying encoding: ", encoding.encoding) | |||
| try: | |||
| with open(self.file_path, newline="", encoding=encoding.encoding) as csvfile: | |||
| with open(self._file_path, newline="", encoding=encoding.encoding) as csvfile: | |||
| docs = self._read_from_file(csvfile) | |||
| break | |||
| except UnicodeDecodeError: | |||
| continue | |||
| else: | |||
| raise RuntimeError(f"Error loading {self.file_path}") from e | |||
| raise RuntimeError(f"Error loading {self._file_path}") from e | |||
| return docs | |||
| def _read_from_file(self, csvfile): | |||
| def _read_from_file(self, csvfile) -> list[Document]: | |||
| docs = [] | |||
| csv_reader = csv.DictReader(csvfile, **self.csv_args) # type: ignore | |||
| for i, row in enumerate(csv_reader): | |||
| @@ -1,39 +1,27 @@ | |||
| import logging | |||
| """Abstract interface for document loader implementations.""" | |||
| import re | |||
| from typing import Optional, cast | |||
| from langchain.document_loaders.base import BaseLoader | |||
| from langchain.document_loaders.helpers import detect_file_encodings | |||
| from langchain.schema import Document | |||
| from core.rag.extractor.extractor_base import BaseExtractor | |||
| from core.rag.extractor.helpers import detect_file_encodings | |||
| from core.rag.models.document import Document | |||
| logger = logging.getLogger(__name__) | |||
| class MarkdownLoader(BaseLoader): | |||
| """Load md files. | |||
| class MarkdownExtractor(BaseExtractor): | |||
| """Load Markdown files. | |||
| Args: | |||
| file_path: Path to the file to load. | |||
| remove_hyperlinks: Whether to remove hyperlinks from the text. | |||
| remove_images: Whether to remove images from the text. | |||
| encoding: File encoding to use. If `None`, the file will be loaded | |||
| with the default system encoding. | |||
| autodetect_encoding: Whether to try to autodetect the file encoding | |||
| if the specified encoding fails. | |||
| """ | |||
| def __init__( | |||
| self, | |||
| file_path: str, | |||
| remove_hyperlinks: bool = True, | |||
| remove_images: bool = True, | |||
| encoding: Optional[str] = None, | |||
| autodetect_encoding: bool = True, | |||
| self, | |||
| file_path: str, | |||
| remove_hyperlinks: bool = True, | |||
| remove_images: bool = True, | |||
| encoding: Optional[str] = None, | |||
| autodetect_encoding: bool = True, | |||
| ): | |||
| """Initialize with file path.""" | |||
| self._file_path = file_path | |||
| @@ -42,7 +30,8 @@ class MarkdownLoader(BaseLoader): | |||
| self._encoding = encoding | |||
| self._autodetect_encoding = autodetect_encoding | |||
| def load(self) -> list[Document]: | |||
| def extract(self) -> list[Document]: | |||
| """Load from file path.""" | |||
| tups = self.parse_tups(self._file_path) | |||
| documents = [] | |||
| for header, value in tups: | |||
| @@ -113,7 +102,6 @@ class MarkdownLoader(BaseLoader): | |||
| if self._autodetect_encoding: | |||
| detected_encodings = detect_file_encodings(filepath) | |||
| for encoding in detected_encodings: | |||
| logger.debug("Trying encoding: ", encoding.encoding) | |||
| try: | |||
| with open(filepath, encoding=encoding.encoding) as f: | |||
| content = f.read() | |||
| @@ -4,9 +4,10 @@ from typing import Any, Optional | |||
| import requests | |||
| from flask import current_app | |||
| from langchain.document_loaders.base import BaseLoader | |||
| from flask_login import current_user | |||
| from langchain.schema import Document | |||
| from core.rag.extractor.extractor_base import BaseExtractor | |||
| from extensions.ext_database import db | |||
| from models.dataset import Document as DocumentModel | |||
| from models.source import DataSourceBinding | |||
| @@ -22,52 +23,37 @@ RETRIEVE_DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}" | |||
| HEADING_TYPE = ['heading_1', 'heading_2', 'heading_3'] | |||
| class NotionLoader(BaseLoader): | |||
| class NotionExtractor(BaseExtractor): | |||
| def __init__( | |||
| self, | |||
| notion_access_token: str, | |||
| notion_workspace_id: str, | |||
| notion_obj_id: str, | |||
| notion_page_type: str, | |||
| document_model: Optional[DocumentModel] = None | |||
| document_model: Optional[DocumentModel] = None, | |||
| notion_access_token: Optional[str] = None | |||
| ): | |||
| self._notion_access_token = None | |||
| self._document_model = document_model | |||
| self._notion_workspace_id = notion_workspace_id | |||
| self._notion_obj_id = notion_obj_id | |||
| self._notion_page_type = notion_page_type | |||
| self._notion_access_token = notion_access_token | |||
| if not self._notion_access_token: | |||
| integration_token = current_app.config.get('NOTION_INTEGRATION_TOKEN') | |||
| if integration_token is None: | |||
| raise ValueError( | |||
| "Must specify `integration_token` or set environment " | |||
| "variable `NOTION_INTEGRATION_TOKEN`." | |||
| ) | |||
| self._notion_access_token = integration_token | |||
| @classmethod | |||
| def from_document(cls, document_model: DocumentModel): | |||
| data_source_info = document_model.data_source_info_dict | |||
| if not data_source_info or 'notion_page_id' not in data_source_info \ | |||
| or 'notion_workspace_id' not in data_source_info: | |||
| raise ValueError("no notion page found") | |||
| notion_workspace_id = data_source_info['notion_workspace_id'] | |||
| notion_obj_id = data_source_info['notion_page_id'] | |||
| notion_page_type = data_source_info['type'] | |||
| notion_access_token = cls._get_access_token(document_model.tenant_id, notion_workspace_id) | |||
| return cls( | |||
| notion_access_token=notion_access_token, | |||
| notion_workspace_id=notion_workspace_id, | |||
| notion_obj_id=notion_obj_id, | |||
| notion_page_type=notion_page_type, | |||
| document_model=document_model | |||
| ) | |||
| def load(self) -> list[Document]: | |||
| if notion_access_token: | |||
| self._notion_access_token = notion_access_token | |||
| else: | |||
| self._notion_access_token = self._get_access_token(current_user.current_tenant_id, | |||
| self._notion_workspace_id) | |||
| if not self._notion_access_token: | |||
| integration_token = current_app.config.get('NOTION_INTEGRATION_TOKEN') | |||
| if integration_token is None: | |||
| raise ValueError( | |||
| "Must specify `integration_token` or set environment " | |||
| "variable `NOTION_INTEGRATION_TOKEN`." | |||
| ) | |||
| self._notion_access_token = integration_token | |||
| def extract(self) -> list[Document]: | |||
| self.update_last_edited_time( | |||
| self._document_model | |||
| ) | |||
| @@ -0,0 +1,72 @@ | |||
| """Abstract interface for document loader implementations.""" | |||
| from collections.abc import Iterator | |||
| from typing import Optional | |||
| from core.rag.extractor.blod.blod import Blob | |||
| from core.rag.extractor.extractor_base import BaseExtractor | |||
| from core.rag.models.document import Document | |||
| from extensions.ext_storage import storage | |||
| class PdfExtractor(BaseExtractor): | |||
| """Load pdf files. | |||
| Args: | |||
| file_path: Path to the file to load. | |||
| """ | |||
| def __init__( | |||
| self, | |||
| file_path: str, | |||
| file_cache_key: Optional[str] = None | |||
| ): | |||
| """Initialize with file path.""" | |||
| self._file_path = file_path | |||
| self._file_cache_key = file_cache_key | |||
| def extract(self) -> list[Document]: | |||
| plaintext_file_exists = False | |||
| if self._file_cache_key: | |||
| try: | |||
| text = storage.load(self._file_cache_key).decode('utf-8') | |||
| plaintext_file_exists = True | |||
| return [Document(page_content=text)] | |||
| except FileNotFoundError: | |||
| pass | |||
| documents = list(self.load()) | |||
| text_list = [] | |||
| for document in documents: | |||
| text_list.append(document.page_content) | |||
| text = "\n\n".join(text_list) | |||
| # save plaintext file for caching | |||
| if not plaintext_file_exists and self._file_cache_key: | |||
| storage.save(self._file_cache_key, text.encode('utf-8')) | |||
| return documents | |||
| def load( | |||
| self, | |||
| ) -> Iterator[Document]: | |||
| """Lazy load given path as pages.""" | |||
| blob = Blob.from_path(self._file_path) | |||
| yield from self.parse(blob) | |||
| def parse(self, blob: Blob) -> Iterator[Document]: | |||
| """Lazily parse the blob.""" | |||
| import pypdfium2 | |||
| with blob.as_bytes_io() as file_path: | |||
| pdf_reader = pypdfium2.PdfDocument(file_path, autoclose=True) | |||
| try: | |||
| for page_number, page in enumerate(pdf_reader): | |||
| text_page = page.get_textpage() | |||
| content = text_page.get_text_range() | |||
| text_page.close() | |||
| page.close() | |||
| metadata = {"source": blob.source, "page": page_number} | |||
| yield Document(page_content=content, metadata=metadata) | |||
| finally: | |||
| pdf_reader.close() | |||
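| # Usage sketch (the cache key is an optional storage key; values are illustrative): | |||
| # | |||
| #     extractor = PdfExtractor("/tmp/manual.pdf", file_cache_key="plaintext/manual.txt") | |||
| #     pages = extractor.extract()   # one Document per page with {"source", "page"} metadata | |||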
| @@ -0,0 +1,50 @@ | |||
| """Abstract interface for document loader implementations.""" | |||
| from typing import Optional | |||
| from core.rag.extractor.extractor_base import BaseExtractor | |||
| from core.rag.extractor.helpers import detect_file_encodings | |||
| from core.rag.models.document import Document | |||
| class TextExtractor(BaseExtractor): | |||
| """Load text files. | |||
| Args: | |||
| file_path: Path to the file to load. | |||
| """ | |||
| def __init__( | |||
| self, | |||
| file_path: str, | |||
| encoding: Optional[str] = None, | |||
| autodetect_encoding: bool = False | |||
| ): | |||
| """Initialize with file path.""" | |||
| self._file_path = file_path | |||
| self._encoding = encoding | |||
| self._autodetect_encoding = autodetect_encoding | |||
| def extract(self) -> list[Document]: | |||
| """Load from file path.""" | |||
| text = "" | |||
| try: | |||
| with open(self._file_path, encoding=self._encoding) as f: | |||
| text = f.read() | |||
| except UnicodeDecodeError as e: | |||
| if self._autodetect_encoding: | |||
| detected_encodings = detect_file_encodings(self._file_path) | |||
| for encoding in detected_encodings: | |||
| try: | |||
| with open(self._file_path, encoding=encoding.encoding) as f: | |||
| text = f.read() | |||
| break | |||
| except UnicodeDecodeError: | |||
| continue | |||
| else: | |||
| raise RuntimeError(f"Error loading {self._file_path}") from e | |||
| except Exception as e: | |||
| raise RuntimeError(f"Error loading {self._file_path}") from e | |||
| metadata = {"source": self._file_path} | |||
| return [Document(page_content=text, metadata=metadata)] | |||
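A short sketch of the encoding fallback (the file path is hypothetical, and the module path is assumed from the surrounding layout): with `autodetect_encoding=True`, a `UnicodeDecodeError` triggers `detect_file_encodings()` and the first encoding that decodes cleanly is used.

```python
# Hypothetical usage of the text extractor's encoding fallback.
from core.rag.extractor.text_extractor import TextExtractor

extractor = TextExtractor("/tmp/notes.txt", encoding=None, autodetect_encoding=True)
docs = extractor.extract()
print(docs[0].metadata["source"])
```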
| @@ -0,0 +1,61 @@ | |||
| import logging | |||
| import os | |||
| from core.rag.extractor.extractor_base import BaseExtractor | |||
| from core.rag.models.document import Document | |||
| logger = logging.getLogger(__name__) | |||
| class UnstructuredWordExtractor(BaseExtractor): | |||
| """Loader that uses unstructured to load word documents. | |||
| """ | |||
| def __init__( | |||
| self, | |||
| file_path: str, | |||
| api_url: str, | |||
| ): | |||
| """Initialize with file path.""" | |||
| self._file_path = file_path | |||
| self._api_url = api_url | |||
| def extract(self) -> list[Document]: | |||
| from unstructured.__version__ import __version__ as __unstructured_version__ | |||
| from unstructured.file_utils.filetype import FileType, detect_filetype | |||
| unstructured_version = tuple( | |||
| [int(x) for x in __unstructured_version__.split(".")] | |||
| ) | |||
| # check the file extension | |||
| try: | |||
| import magic # noqa: F401 | |||
| is_doc = detect_filetype(self._file_path) == FileType.DOC | |||
| except ImportError: | |||
| _, extension = os.path.splitext(str(self._file_path)) | |||
| is_doc = extension == ".doc" | |||
| if is_doc and unstructured_version < (0, 4, 11): | |||
| raise ValueError( | |||
| f"You are on unstructured version {__unstructured_version__}. " | |||
| "Partitioning .doc files is only supported in unstructured>=0.4.11. " | |||
| "Please upgrade the unstructured package and try again." | |||
| ) | |||
| if is_doc: | |||
| from unstructured.partition.doc import partition_doc | |||
| elements = partition_doc(filename=self._file_path) | |||
| else: | |||
| from unstructured.partition.docx import partition_docx | |||
| elements = partition_docx(filename=self._file_path) | |||
| from unstructured.chunking.title import chunk_by_title | |||
| chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0) | |||
| documents = [] | |||
| for chunk in chunks: | |||
| text = chunk.text.strip() | |||
| documents.append(Document(page_content=text)) | |||
| return documents | |||
| @@ -2,13 +2,14 @@ import base64 | |||
| import logging | |||
| from bs4 import BeautifulSoup | |||
| from langchain.document_loaders.base import BaseLoader | |||
| from langchain.schema import Document | |||
| from core.rag.extractor.extractor_base import BaseExtractor | |||
| from core.rag.models.document import Document | |||
| logger = logging.getLogger(__name__) | |||
| class UnstructuredEmailLoader(BaseLoader): | |||
| class UnstructuredEmailExtractor(BaseExtractor): | |||
| """Load msg files. | |||
| Args: | |||
| file_path: Path to the file to load. | |||
| @@ -23,7 +24,7 @@ class UnstructuredEmailLoader(BaseLoader): | |||
| self._file_path = file_path | |||
| self._api_url = api_url | |||
| def load(self) -> list[Document]: | |||
| def extract(self) -> list[Document]: | |||
| from unstructured.partition.email import partition_email | |||
| elements = partition_email(filename=self._file_path, api_url=self._api_url) | |||
| @@ -1,12 +1,12 @@ | |||
| import logging | |||
| from langchain.document_loaders.base import BaseLoader | |||
| from langchain.schema import Document | |||
| from core.rag.extractor.extractor_base import BaseExtractor | |||
| from core.rag.models.document import Document | |||
| logger = logging.getLogger(__name__) | |||
| class UnstructuredMarkdownLoader(BaseLoader): | |||
| class UnstructuredMarkdownExtractor(BaseExtractor): | |||
| """Load md files. | |||
| @@ -33,7 +33,7 @@ class UnstructuredMarkdownLoader(BaseLoader): | |||
| self._file_path = file_path | |||
| self._api_url = api_url | |||
| def load(self) -> list[Document]: | |||
| def extract(self) -> list[Document]: | |||
| from unstructured.partition.md import partition_md | |||
| elements = partition_md(filename=self._file_path, api_url=self._api_url) | |||
| @@ -1,12 +1,12 @@ | |||
| import logging | |||
| from langchain.document_loaders.base import BaseLoader | |||
| from langchain.schema import Document | |||
| from core.rag.extractor.extractor_base import BaseExtractor | |||
| from core.rag.models.document import Document | |||
| logger = logging.getLogger(__name__) | |||
| class UnstructuredMsgLoader(BaseLoader): | |||
| class UnstructuredMsgExtractor(BaseExtractor): | |||
| """Load msg files. | |||
| @@ -23,7 +23,7 @@ class UnstructuredMsgLoader(BaseLoader): | |||
| self._file_path = file_path | |||
| self._api_url = api_url | |||
| def load(self) -> list[Document]: | |||
| def extract(self) -> list[Document]: | |||
| from unstructured.partition.msg import partition_msg | |||
| elements = partition_msg(filename=self._file_path, api_url=self._api_url) | |||
| @@ -1,11 +1,12 @@ | |||
| import logging | |||
| from langchain.document_loaders.base import BaseLoader | |||
| from langchain.schema import Document | |||
| from core.rag.extractor.extractor_base import BaseExtractor | |||
| from core.rag.models.document import Document | |||
| logger = logging.getLogger(__name__) | |||
| class UnstructuredPPTLoader(BaseLoader): | |||
| class UnstructuredPPTExtractor(BaseExtractor): | |||
| """Load msg files. | |||
| @@ -14,15 +15,15 @@ class UnstructuredPPTLoader(BaseLoader): | |||
| """ | |||
| def __init__( | |||
| self, | |||
| file_path: str, | |||
| api_url: str | |||
| self, | |||
| file_path: str, | |||
| api_url: str | |||
| ): | |||
| """Initialize with file path.""" | |||
| self._file_path = file_path | |||
| self._api_url = api_url | |||
| def load(self) -> list[Document]: | |||
| def extract(self) -> list[Document]: | |||
| from unstructured.partition.ppt import partition_ppt | |||
| elements = partition_ppt(filename=self._file_path, api_url=self._api_url) | |||
| @@ -1,10 +1,12 @@ | |||
| import logging | |||
| from langchain.document_loaders.base import BaseLoader | |||
| from langchain.schema import Document | |||
| from core.rag.extractor.extractor_base import BaseExtractor | |||
| from core.rag.models.document import Document | |||
| logger = logging.getLogger(__name__) | |||
| class UnstructuredPPTXLoader(BaseLoader): | |||
| class UnstructuredPPTXExtractor(BaseExtractor): | |||
| """Load msg files. | |||
| @@ -13,15 +15,15 @@ class UnstructuredPPTXLoader(BaseLoader): | |||
| """ | |||
| def __init__( | |||
| self, | |||
| file_path: str, | |||
| api_url: str | |||
| self, | |||
| file_path: str, | |||
| api_url: str | |||
| ): | |||
| """Initialize with file path.""" | |||
| self._file_path = file_path | |||
| self._api_url = api_url | |||
| def load(self) -> list[Document]: | |||
| def extract(self) -> list[Document]: | |||
| from unstructured.partition.pptx import partition_pptx | |||
| elements = partition_pptx(filename=self._file_path, api_url=self._api_url) | |||
| @@ -1,12 +1,12 @@ | |||
| import logging | |||
| from langchain.document_loaders.base import BaseLoader | |||
| from langchain.schema import Document | |||
| from core.rag.extractor.extractor_base import BaseExtractor | |||
| from core.rag.models.document import Document | |||
| logger = logging.getLogger(__name__) | |||
| class UnstructuredTextLoader(BaseLoader): | |||
| class UnstructuredTextExtractor(BaseExtractor): | |||
| """Load msg files. | |||
| @@ -23,7 +23,7 @@ class UnstructuredTextLoader(BaseLoader): | |||
| self._file_path = file_path | |||
| self._api_url = api_url | |||
| def load(self) -> list[Document]: | |||
| def extract(self) -> list[Document]: | |||
| from unstructured.partition.text import partition_text | |||
| elements = partition_text(filename=self._file_path, api_url=self._api_url) | |||
| @@ -1,12 +1,12 @@ | |||
| import logging | |||
| from langchain.document_loaders.base import BaseLoader | |||
| from langchain.schema import Document | |||
| from core.rag.extractor.extractor_base import BaseExtractor | |||
| from core.rag.models.document import Document | |||
| logger = logging.getLogger(__name__) | |||
| class UnstructuredXmlLoader(BaseLoader): | |||
| class UnstructuredXmlExtractor(BaseExtractor): | |||
| """Load msg files. | |||
| @@ -23,7 +23,7 @@ class UnstructuredXmlLoader(BaseLoader): | |||
| self._file_path = file_path | |||
| self._api_url = api_url | |||
| def load(self) -> list[Document]: | |||
| def extract(self) -> list[Document]: | |||
| from unstructured.partition.xml import partition_xml | |||
| elements = partition_xml(filename=self._file_path, xml_keep_tags=True, api_url=self._api_url) | |||
| @@ -0,0 +1,62 @@ | |||
| """Abstract interface for document loader implementations.""" | |||
| import os | |||
| import tempfile | |||
| from urllib.parse import urlparse | |||
| import requests | |||
| from core.rag.extractor.extractor_base import BaseExtractor | |||
| from core.rag.models.document import Document | |||
| class WordExtractor(BaseExtractor): | |||
| """Load pdf files. | |||
| Args: | |||
| file_path: Path to the file to load. | |||
| """ | |||
| def __init__(self, file_path: str): | |||
| """Initialize with file path.""" | |||
| self.file_path = file_path | |||
| if "~" in self.file_path: | |||
| self.file_path = os.path.expanduser(self.file_path) | |||
| # If the file is a web path, download it to a temporary file, and use that | |||
| if not os.path.isfile(self.file_path) and self._is_valid_url(self.file_path): | |||
| r = requests.get(self.file_path) | |||
| if r.status_code != 200: | |||
| raise ValueError( | |||
| "Check the url of your file; returned status code %s" | |||
| % r.status_code | |||
| ) | |||
| self.web_path = self.file_path | |||
| self.temp_file = tempfile.NamedTemporaryFile() | |||
| self.temp_file.write(r.content) | |||
| self.file_path = self.temp_file.name | |||
| elif not os.path.isfile(self.file_path): | |||
| raise ValueError("File path %s is not a valid file or url" % self.file_path) | |||
| def __del__(self) -> None: | |||
| if hasattr(self, "temp_file"): | |||
| self.temp_file.close() | |||
| def extract(self) -> list[Document]: | |||
| """Load given path as single page.""" | |||
| import docx2txt | |||
| return [ | |||
| Document( | |||
| page_content=docx2txt.process(self.file_path), | |||
| metadata={"source": self.file_path}, | |||
| ) | |||
| ] | |||
| @staticmethod | |||
| def _is_valid_url(url: str) -> bool: | |||
| """Check if the url is valid.""" | |||
| parsed = urlparse(url) | |||
| return bool(parsed.netloc) and bool(parsed.scheme) | |||
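`WordExtractor` accepts either a local path or an http(s) URL; URLs are downloaded into a `NamedTemporaryFile` that lives as long as the extractor instance. A sketch with hypothetical inputs:

```python
# Hypothetical usage; the module path and both inputs are illustrative only.
from core.rag.extractor.word_extractor import WordExtractor

local_docs = WordExtractor("/tmp/report.docx").extract()
remote_docs = WordExtractor("https://example.com/report.docx").extract()  # downloaded to a temp file first
```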
| @@ -0,0 +1,8 @@ | |||
| from enum import Enum | |||
| class IndexType(Enum): | |||
| PARAGRAPH_INDEX = "text_model" | |||
| QA_INDEX = "qa_model" | |||
| PARENT_CHILD_INDEX = "parent_child_index" | |||
| SUMMARY_INDEX = "summary_index" | |||
| @@ -0,0 +1,70 @@ | |||
| """Abstract interface for document loader implementations.""" | |||
| from abc import ABC, abstractmethod | |||
| from typing import Optional | |||
| from langchain.text_splitter import TextSplitter | |||
| from core.model_manager import ModelInstance | |||
| from core.rag.extractor.entity.extract_setting import ExtractSetting | |||
| from core.rag.models.document import Document | |||
| from core.splitter.fixed_text_splitter import EnhanceRecursiveCharacterTextSplitter, FixedRecursiveCharacterTextSplitter | |||
| from models.dataset import Dataset, DatasetProcessRule | |||
| class BaseIndexProcessor(ABC): | |||
| """Interface for extract files. | |||
| """ | |||
| @abstractmethod | |||
| def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: | |||
| raise NotImplementedError | |||
| @abstractmethod | |||
| def transform(self, documents: list[Document], **kwargs) -> list[Document]: | |||
| raise NotImplementedError | |||
| @abstractmethod | |||
| def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True): | |||
| raise NotImplementedError | |||
| def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True): | |||
| raise NotImplementedError | |||
| @abstractmethod | |||
| def retrieve(self, retrival_method: str, query: str, dataset: Dataset, top_k: int, | |||
| score_threshold: float, reranking_model: dict) -> list[Document]: | |||
| raise NotImplementedError | |||
| def _get_splitter(self, processing_rule: dict, | |||
| embedding_model_instance: Optional[ModelInstance]) -> TextSplitter: | |||
| """ | |||
| Get the TextSplitter according to the processing rule. | |||
| """ | |||
| if processing_rule['mode'] == "custom": | |||
| # The user-defined segmentation rule | |||
| rules = processing_rule['rules'] | |||
| segmentation = rules["segmentation"] | |||
| if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > 1000: | |||
| raise ValueError("Custom segment length should be between 50 and 1000.") | |||
| separator = segmentation["separator"] | |||
| if separator: | |||
| separator = separator.replace('\\n', '\n') | |||
| character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder( | |||
| chunk_size=segmentation["max_tokens"], | |||
| chunk_overlap=0, | |||
| fixed_separator=separator, | |||
| separators=["\n\n", "。", ".", " ", ""], | |||
| embedding_model_instance=embedding_model_instance | |||
| ) | |||
| else: | |||
| # Automatic segmentation | |||
| character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder( | |||
| chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'], | |||
| chunk_overlap=0, | |||
| separators=["\n\n", "。", ".", " ", ""], | |||
| embedding_model_instance=embedding_model_instance | |||
| ) | |||
| return character_splitter | |||
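To make the branching in `_get_splitter()` concrete, here is a hypothetical `processing_rule` dict for the custom mode (field names taken from the checks above; the values are examples only):

```python
# Hypothetical rule exercising the "custom" branch of _get_splitter().
process_rule = {
    "mode": "custom",
    "rules": {
        "segmentation": {
            "max_tokens": 500,   # must stay within the 50..1000 bounds enforced above
            "separator": "\\n",  # stored escaped; _get_splitter() turns it into a real newline
        }
    },
}
# Any other mode (e.g. {"mode": "automatic"}) falls through to
# EnhanceRecursiveCharacterTextSplitter with DatasetProcessRule.AUTOMATIC_RULES.
```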
| @@ -0,0 +1,28 @@ | |||
| """Abstract interface for document loader implementations.""" | |||
| from core.rag.index_processor.constant.index_type import IndexType | |||
| from core.rag.index_processor.index_processor_base import BaseIndexProcessor | |||
| from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor | |||
| from core.rag.index_processor.processor.qa_index_processor import QAIndexProcessor | |||
| class IndexProcessorFactory: | |||
| """IndexProcessorInit. | |||
| """ | |||
| def __init__(self, index_type: str): | |||
| self._index_type = index_type | |||
| def init_index_processor(self) -> BaseIndexProcessor: | |||
| """Init index processor.""" | |||
| if not self._index_type: | |||
| raise ValueError("Index type must be specified.") | |||
| if self._index_type == IndexType.PARAGRAPH_INDEX.value: | |||
| return ParagraphIndexProcessor() | |||
| elif self._index_type == IndexType.QA_INDEX.value: | |||
| return QAIndexProcessor() | |||
| else: | |||
| raise ValueError(f"Index type {self._index_type} is not supported.") | |||
| @@ -0,0 +1,92 @@ | |||
| """Paragraph index processor.""" | |||
| import uuid | |||
| from typing import Optional | |||
| from core.rag.cleaner.clean_processor import CleanProcessor | |||
| from core.rag.datasource.keyword.keyword_factory import Keyword | |||
| from core.rag.datasource.retrieval_service import RetrievalService | |||
| from core.rag.datasource.vdb.vector_factory import Vector | |||
| from core.rag.extractor.entity.extract_setting import ExtractSetting | |||
| from core.rag.extractor.extract_processor import ExtractProcessor | |||
| from core.rag.index_processor.index_processor_base import BaseIndexProcessor | |||
| from core.rag.models.document import Document | |||
| from libs import helper | |||
| from models.dataset import Dataset | |||
| class ParagraphIndexProcessor(BaseIndexProcessor): | |||
| def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: | |||
| text_docs = ExtractProcessor.extract(extract_setting=extract_setting, | |||
| is_automatic=kwargs.get('process_rule_mode') == "automatic") | |||
| return text_docs | |||
| def transform(self, documents: list[Document], **kwargs) -> list[Document]: | |||
| # Split the text documents into nodes. | |||
| splitter = self._get_splitter(processing_rule=kwargs.get('process_rule'), | |||
| embedding_model_instance=kwargs.get('embedding_model_instance')) | |||
| all_documents = [] | |||
| for document in documents: | |||
| # document clean | |||
| document_text = CleanProcessor.clean(document.page_content, kwargs.get('process_rule')) | |||
| document.page_content = document_text | |||
| # parse document to nodes | |||
| document_nodes = splitter.split_documents([document]) | |||
| split_documents = [] | |||
| for document_node in document_nodes: | |||
| if document_node.page_content.strip(): | |||
| doc_id = str(uuid.uuid4()) | |||
| hash = helper.generate_text_hash(document_node.page_content) | |||
| document_node.metadata['doc_id'] = doc_id | |||
| document_node.metadata['doc_hash'] = hash | |||
| # strip the leading splitter character, if any | |||
| page_content = document_node.page_content | |||
| if page_content.startswith(".") or page_content.startswith("。"): | |||
| page_content = page_content[1:] | |||
| document_node.page_content = page_content | |||
| split_documents.append(document_node) | |||
| all_documents.extend(split_documents) | |||
| return all_documents | |||
| def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True): | |||
| if dataset.indexing_technique == 'high_quality': | |||
| vector = Vector(dataset) | |||
| vector.create(documents) | |||
| if with_keywords: | |||
| keyword = Keyword(dataset) | |||
| keyword.create(documents) | |||
| def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True): | |||
| if dataset.indexing_technique == 'high_quality': | |||
| vector = Vector(dataset) | |||
| if node_ids: | |||
| vector.delete_by_ids(node_ids) | |||
| else: | |||
| vector.delete() | |||
| if with_keywords: | |||
| keyword = Keyword(dataset) | |||
| if node_ids: | |||
| keyword.delete_by_ids(node_ids) | |||
| else: | |||
| keyword.delete() | |||
| def retrieve(self, retrival_method: str, query: str, dataset: Dataset, top_k: int, | |||
| score_threshold: float, reranking_model: dict) -> list[Document]: | |||
| # Set search parameters. | |||
| results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id, query=query, | |||
| top_k=top_k, score_threshold=score_threshold, | |||
| reranking_model=reranking_model) | |||
| # Organize results. | |||
| docs = [] | |||
| for result in results: | |||
| metadata = result.metadata | |||
| metadata['score'] = result.score | |||
| if result.score > score_threshold: | |||
| doc = Document(page_content=result.page_content, metadata=metadata) | |||
| docs.append(doc) | |||
| return docs | |||
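A compressed sketch of how the three stages above chain together; `extract_setting`, `process_rule`, and `dataset` are assumed to be supplied by the calling service (e.g. the indexing runner) and are not defined here:

```python
# Hypothetical driver code; all inputs come from the caller.
processor = ParagraphIndexProcessor()
text_docs = processor.extract(extract_setting, process_rule_mode="automatic")
nodes = processor.transform(text_docs,
                            process_rule=process_rule,
                            embedding_model_instance=None)  # Optional[ModelInstance]; None is accepted
processor.load(dataset, nodes, with_keywords=True)
```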
| @@ -0,0 +1,161 @@ | |||
| """Paragraph index processor.""" | |||
| import logging | |||
| import re | |||
| import threading | |||
| import uuid | |||
| from typing import Optional | |||
| import pandas as pd | |||
| from flask import Flask, current_app | |||
| from flask_login import current_user | |||
| from werkzeug.datastructures import FileStorage | |||
| from core.generator.llm_generator import LLMGenerator | |||
| from core.rag.cleaner.clean_processor import CleanProcessor | |||
| from core.rag.datasource.retrieval_service import RetrievalService | |||
| from core.rag.datasource.vdb.vector_factory import Vector | |||
| from core.rag.extractor.entity.extract_setting import ExtractSetting | |||
| from core.rag.extractor.extract_processor import ExtractProcessor | |||
| from core.rag.index_processor.index_processor_base import BaseIndexProcessor | |||
| from core.rag.models.document import Document | |||
| from libs import helper | |||
| from models.dataset import Dataset | |||
| class QAIndexProcessor(BaseIndexProcessor): | |||
| def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: | |||
| text_docs = ExtractProcessor.extract(extract_setting=extract_setting, | |||
| is_automatic=kwargs.get('process_rule_mode') == "automatic") | |||
| return text_docs | |||
| def transform(self, documents: list[Document], **kwargs) -> list[Document]: | |||
| splitter = self._get_splitter(processing_rule=kwargs.get('process_rule'), | |||
| embedding_model_instance=None) | |||
| # Split the text documents into nodes. | |||
| all_documents = [] | |||
| all_qa_documents = [] | |||
| for document in documents: | |||
| # document clean | |||
| document_text = CleanProcessor.clean(document.page_content, kwargs.get('process_rule')) | |||
| document.page_content = document_text | |||
| # parse document to nodes | |||
| document_nodes = splitter.split_documents([document]) | |||
| split_documents = [] | |||
| for document_node in document_nodes: | |||
| if document_node.page_content.strip(): | |||
| doc_id = str(uuid.uuid4()) | |||
| hash = helper.generate_text_hash(document_node.page_content) | |||
| document_node.metadata['doc_id'] = doc_id | |||
| document_node.metadata['doc_hash'] = hash | |||
| # strip the leading splitter character, if any | |||
| page_content = document_node.page_content | |||
| if page_content.startswith(".") or page_content.startswith("。"): | |||
| page_content = page_content[1:] | |||
| document_node.page_content = page_content | |||
| split_documents.append(document_node) | |||
| all_documents.extend(split_documents) | |||
| for i in range(0, len(all_documents), 10): | |||
| threads = [] | |||
| sub_documents = all_documents[i:i + 10] | |||
| for doc in sub_documents: | |||
| document_format_thread = threading.Thread(target=self._format_qa_document, kwargs={ | |||
| 'flask_app': current_app._get_current_object(), | |||
| 'tenant_id': current_user.current_tenant.id, | |||
| 'document_node': doc, | |||
| 'all_qa_documents': all_qa_documents, | |||
| 'document_language': kwargs.get('document_language', 'English')}) | |||
| threads.append(document_format_thread) | |||
| document_format_thread.start() | |||
| for thread in threads: | |||
| thread.join() | |||
| return all_qa_documents | |||
| def format_by_template(self, file: FileStorage, **kwargs) -> list[Document]: | |||
| # check file type | |||
| if not file.filename.endswith('.csv'): | |||
| raise ValueError("Invalid file type. Only CSV files are allowed") | |||
| try: | |||
| # read_csv treats the first row as the header, so data rows start from the second row | |||
| df = pd.read_csv(file) | |||
| text_docs = [] | |||
| for index, row in df.iterrows(): | |||
| data = Document(page_content=row[0], metadata={'answer': row[1]}) | |||
| text_docs.append(data) | |||
| if len(text_docs) == 0: | |||
| raise ValueError("The CSV file is empty.") | |||
| except Exception as e: | |||
| raise ValueError(str(e)) | |||
| return text_docs | |||
| def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True): | |||
| if dataset.indexing_technique == 'high_quality': | |||
| vector = Vector(dataset) | |||
| vector.create(documents) | |||
| def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True): | |||
| vector = Vector(dataset) | |||
| if node_ids: | |||
| vector.delete_by_ids(node_ids) | |||
| else: | |||
| vector.delete() | |||
| def retrieve(self, retrival_method: str, query: str, dataset: Dataset, top_k: int, | |||
| score_threshold: float, reranking_model: dict): | |||
| # Set search parameters. | |||
| results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id, query=query, | |||
| top_k=top_k, score_threshold=score_threshold, | |||
| reranking_model=reranking_model) | |||
| # Organize results. | |||
| docs = [] | |||
| for result in results: | |||
| metadata = result.metadata | |||
| metadata['score'] = result.score | |||
| if result.score > score_threshold: | |||
| doc = Document(page_content=result.page_content, metadata=metadata) | |||
| docs.append(doc) | |||
| return docs | |||
| def _format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language): | |||
| format_documents = [] | |||
| if document_node.page_content is None or not document_node.page_content.strip(): | |||
| return | |||
| with flask_app.app_context(): | |||
| try: | |||
| # qa model document | |||
| response = LLMGenerator.generate_qa_document(tenant_id, document_node.page_content, document_language) | |||
| document_qa_list = self._format_split_text(response) | |||
| qa_documents = [] | |||
| for result in document_qa_list: | |||
| qa_document = Document(page_content=result['question'], metadata=document_node.metadata.copy()) | |||
| doc_id = str(uuid.uuid4()) | |||
| hash = helper.generate_text_hash(result['question']) | |||
| qa_document.metadata['answer'] = result['answer'] | |||
| qa_document.metadata['doc_id'] = doc_id | |||
| qa_document.metadata['doc_hash'] = hash | |||
| qa_documents.append(qa_document) | |||
| format_documents.extend(qa_documents) | |||
| except Exception as e: | |||
| logging.exception(e) | |||
| all_qa_documents.extend(format_documents) | |||
| def _format_split_text(self, text): | |||
| regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)" | |||
| matches = re.findall(regex, text, re.UNICODE) | |||
| return [ | |||
| { | |||
| "question": q, | |||
| "answer": re.sub(r"\n\s*", "\n", a.strip()) | |||
| } | |||
| for q, a in matches if q and a | |||
| ] | |||
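To illustrate what `_format_split_text()` expects from the LLM response, a worked example (the sample text is made up):

```python
# Hypothetical LLM output in the Qn:/An: shape the regex above parses.
sample = (
    "Q1: What is a paragraph index?\n"
    "A1: One chunk per paragraph.\n"
    "Q2: And a QA index?\n"
    "A2: Question/answer pairs generated\nby the LLM."
)
QAIndexProcessor()._format_split_text(sample)
# -> [{'question': 'What is a paragraph index?', 'answer': 'One chunk per paragraph.'},
#     {'question': 'And a QA index?', 'answer': 'Question/answer pairs generated\nby the LLM.'}]
```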
| @@ -0,0 +1,16 @@ | |||
| from typing import Optional | |||
| from pydantic import BaseModel, Field | |||
| class Document(BaseModel): | |||
| """Class for storing a piece of text and associated metadata.""" | |||
| page_content: str | |||
| metadata: Optional[dict] = Field(default_factory=dict) | |||
| """Arbitrary metadata about the page content (e.g., source, relationships to other | |||
| documents, etc.). | |||
| """ | |||
| @@ -21,9 +21,9 @@ from pydantic import BaseModel, Field | |||
| from regex import regex | |||
| from core.chain.llm_chain import LLMChain | |||
| from core.data_loader import file_extractor | |||
| from core.data_loader.file_extractor import FileExtractor | |||
| from core.entities.application_entities import ModelConfigEntity | |||
| from core.rag.extractor import extract_processor | |||
| from core.rag.extractor.extract_processor import ExtractProcessor | |||
| FULL_TEMPLATE = """ | |||
| TITLE: {title} | |||
| @@ -146,7 +146,7 @@ def get_url(url: str) -> str: | |||
| headers = { | |||
| "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" | |||
| } | |||
| supported_content_types = file_extractor.SUPPORT_URL_CONTENT_TYPES + ["text/html"] | |||
| supported_content_types = extract_processor.SUPPORT_URL_CONTENT_TYPES + ["text/html"] | |||
| head_response = requests.head(url, headers=headers, allow_redirects=True, timeout=(5, 10)) | |||
| @@ -158,8 +158,8 @@ def get_url(url: str) -> str: | |||
| if main_content_type not in supported_content_types: | |||
| return "Unsupported content-type [{}] of URL.".format(main_content_type) | |||
| if main_content_type in file_extractor.SUPPORT_URL_CONTENT_TYPES: | |||
| return FileExtractor.load_from_url(url, return_text=True) | |||
| if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES: | |||
| return ExtractProcessor.load_from_url(url, return_text=True) | |||
| response = requests.get(url, headers=headers, allow_redirects=True, timeout=(5, 30)) | |||
| a = extract_using_readabilipy(response.text) | |||
| @@ -6,15 +6,12 @@ from langchain.tools import BaseTool | |||
| from pydantic import BaseModel, Field | |||
| from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler | |||
| from core.embedding.cached_embedding import CacheEmbedding | |||
| from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError | |||
| from core.index.keyword_table_index.keyword_table_index import KeywordTableConfig, KeywordTableIndex | |||
| from core.model_manager import ModelManager | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from core.rag.datasource.retrieval_service import RetrievalService | |||
| from core.rerank.rerank import RerankRunner | |||
| from extensions.ext_database import db | |||
| from models.dataset import Dataset, Document, DocumentSegment | |||
| from services.retrieval_service import RetrievalService | |||
| default_retrieval_model = { | |||
| 'search_method': 'semantic_search', | |||
| @@ -174,76 +171,24 @@ class DatasetMultiRetrieverTool(BaseTool): | |||
| if dataset.indexing_technique == "economy": | |||
| # use keyword table query | |||
| kw_table_index = KeywordTableIndex( | |||
| dataset=dataset, | |||
| config=KeywordTableConfig( | |||
| max_keywords_per_chunk=5 | |||
| ) | |||
| ) | |||
| documents = kw_table_index.search(query, search_kwargs={'k': self.top_k}) | |||
| documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'], | |||
| dataset_id=dataset.id, | |||
| query=query, | |||
| top_k=self.top_k | |||
| ) | |||
| if documents: | |||
| all_documents.extend(documents) | |||
| else: | |||
| try: | |||
| model_manager = ModelManager() | |||
| embedding_model = model_manager.get_model_instance( | |||
| tenant_id=dataset.tenant_id, | |||
| provider=dataset.embedding_model_provider, | |||
| model_type=ModelType.TEXT_EMBEDDING, | |||
| model=dataset.embedding_model | |||
| ) | |||
| except LLMBadRequestError: | |||
| return [] | |||
| except ProviderTokenNotInitError: | |||
| return [] | |||
| embeddings = CacheEmbedding(embedding_model) | |||
| documents = [] | |||
| threads = [] | |||
| if self.top_k > 0: | |||
| # retrieval_model source with semantic | |||
| if retrieval_model['search_method'] == 'semantic_search' or retrieval_model[ | |||
| 'search_method'] == 'hybrid_search': | |||
| embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={ | |||
| 'flask_app': current_app._get_current_object(), | |||
| 'dataset_id': str(dataset.id), | |||
| 'query': query, | |||
| 'top_k': self.top_k, | |||
| 'score_threshold': self.score_threshold, | |||
| 'reranking_model': None, | |||
| 'all_documents': documents, | |||
| 'search_method': 'hybrid_search', | |||
| 'embeddings': embeddings | |||
| }) | |||
| threads.append(embedding_thread) | |||
| embedding_thread.start() | |||
| # retrieval_model source with full text | |||
| if retrieval_model['search_method'] == 'full_text_search' or retrieval_model[ | |||
| 'search_method'] == 'hybrid_search': | |||
| full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, | |||
| kwargs={ | |||
| 'flask_app': current_app._get_current_object(), | |||
| 'dataset_id': str(dataset.id), | |||
| 'query': query, | |||
| 'search_method': 'hybrid_search', | |||
| 'embeddings': embeddings, | |||
| 'score_threshold': retrieval_model[ | |||
| 'score_threshold'] if retrieval_model[ | |||
| 'score_threshold_enabled'] else None, | |||
| 'top_k': self.top_k, | |||
| 'reranking_model': retrieval_model[ | |||
| 'reranking_model'] if retrieval_model[ | |||
| 'reranking_enable'] else None, | |||
| 'all_documents': documents | |||
| }) | |||
| threads.append(full_text_index_thread) | |||
| full_text_index_thread.start() | |||
| for thread in threads: | |||
| thread.join() | |||
| # retrieval source | |||
| documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'], | |||
| dataset_id=dataset.id, | |||
| query=query, | |||
| top_k=self.top_k, | |||
| score_threshold=retrieval_model['score_threshold'] | |||
| if retrieval_model['score_threshold_enabled'] else None, | |||
| reranking_model=retrieval_model['reranking_model'] | |||
| if retrieval_model['reranking_enable'] else None | |||
| ) | |||
| all_documents.extend(documents) | |||
| @@ -1,20 +1,12 @@ | |||
| import threading | |||
| from typing import Optional | |||
| from flask import current_app | |||
| from langchain.tools import BaseTool | |||
| from pydantic import BaseModel, Field | |||
| from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler | |||
| from core.embedding.cached_embedding import CacheEmbedding | |||
| from core.index.keyword_table_index.keyword_table_index import KeywordTableConfig, KeywordTableIndex | |||
| from core.model_manager import ModelManager | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from core.model_runtime.errors.invoke import InvokeAuthorizationError | |||
| from core.rerank.rerank import RerankRunner | |||
| from core.rag.datasource.retrieval_service import RetrievalService | |||
| from extensions.ext_database import db | |||
| from models.dataset import Dataset, Document, DocumentSegment | |||
| from services.retrieval_service import RetrievalService | |||
| default_retrieval_model = { | |||
| 'search_method': 'semantic_search', | |||
| @@ -77,94 +69,24 @@ class DatasetRetrieverTool(BaseTool): | |||
| retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model | |||
| if dataset.indexing_technique == "economy": | |||
| # use keyword table query | |||
| kw_table_index = KeywordTableIndex( | |||
| dataset=dataset, | |||
| config=KeywordTableConfig( | |||
| max_keywords_per_chunk=5 | |||
| ) | |||
| ) | |||
| documents = kw_table_index.search(query, search_kwargs={'k': self.top_k}) | |||
| documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'], | |||
| dataset_id=dataset.id, | |||
| query=query, | |||
| top_k=self.top_k | |||
| ) | |||
| return str("\n".join([document.page_content for document in documents])) | |||
| else: | |||
| # get embedding model instance | |||
| try: | |||
| model_manager = ModelManager() | |||
| embedding_model = model_manager.get_model_instance( | |||
| tenant_id=dataset.tenant_id, | |||
| provider=dataset.embedding_model_provider, | |||
| model_type=ModelType.TEXT_EMBEDDING, | |||
| model=dataset.embedding_model | |||
| ) | |||
| except InvokeAuthorizationError: | |||
| return '' | |||
| embeddings = CacheEmbedding(embedding_model) | |||
| documents = [] | |||
| threads = [] | |||
| if self.top_k > 0: | |||
| # retrieval source with semantic | |||
| if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search': | |||
| embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={ | |||
| 'flask_app': current_app._get_current_object(), | |||
| 'dataset_id': str(dataset.id), | |||
| 'query': query, | |||
| 'top_k': self.top_k, | |||
| 'score_threshold': retrieval_model['score_threshold'] if retrieval_model[ | |||
| 'score_threshold_enabled'] else None, | |||
| 'reranking_model': retrieval_model['reranking_model'] if retrieval_model[ | |||
| 'reranking_enable'] else None, | |||
| 'all_documents': documents, | |||
| 'search_method': retrieval_model['search_method'], | |||
| 'embeddings': embeddings | |||
| }) | |||
| threads.append(embedding_thread) | |||
| embedding_thread.start() | |||
| # retrieval_model source with full text | |||
| if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search': | |||
| full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={ | |||
| 'flask_app': current_app._get_current_object(), | |||
| 'dataset_id': str(dataset.id), | |||
| 'query': query, | |||
| 'search_method': retrieval_model['search_method'], | |||
| 'embeddings': embeddings, | |||
| 'score_threshold': retrieval_model['score_threshold'] if retrieval_model[ | |||
| 'score_threshold_enabled'] else None, | |||
| 'top_k': self.top_k, | |||
| 'reranking_model': retrieval_model['reranking_model'] if retrieval_model[ | |||
| 'reranking_enable'] else None, | |||
| 'all_documents': documents | |||
| }) | |||
| threads.append(full_text_index_thread) | |||
| full_text_index_thread.start() | |||
| for thread in threads: | |||
| thread.join() | |||
| # hybrid search: rerank after all documents have been searched | |||
| if retrieval_model['search_method'] == 'hybrid_search': | |||
| # get rerank model instance | |||
| try: | |||
| model_manager = ModelManager() | |||
| rerank_model_instance = model_manager.get_model_instance( | |||
| tenant_id=dataset.tenant_id, | |||
| provider=retrieval_model['reranking_model']['reranking_provider_name'], | |||
| model_type=ModelType.RERANK, | |||
| model=retrieval_model['reranking_model']['reranking_model_name'] | |||
| ) | |||
| except InvokeAuthorizationError: | |||
| return '' | |||
| rerank_runner = RerankRunner(rerank_model_instance) | |||
| documents = rerank_runner.run( | |||
| query=query, | |||
| documents=documents, | |||
| score_threshold=retrieval_model['score_threshold'] if retrieval_model[ | |||
| 'score_threshold_enabled'] else None, | |||
| top_n=self.top_k | |||
| ) | |||
| # retrieval source | |||
| documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'], | |||
| dataset_id=dataset.id, | |||
| query=query, | |||
| top_k=self.top_k, | |||
| score_threshold=retrieval_model['score_threshold'] | |||
| if retrieval_model['score_threshold_enabled'] else None, | |||
| reranking_model=retrieval_model['reranking_model'] | |||
| if retrieval_model['reranking_enable'] else None | |||
| ) | |||
| else: | |||
| documents = [] | |||
| @@ -234,4 +156,4 @@ class DatasetRetrieverTool(BaseTool): | |||
| return str("\n".join(document_context_list)) | |||
| async def _arun(self, tool_input: str) -> str: | |||
| raise NotImplementedError() | |||
| raise NotImplementedError() | |||
| @@ -21,9 +21,9 @@ from pydantic import BaseModel, Field | |||
| from regex import regex | |||
| from core.chain.llm_chain import LLMChain | |||
| from core.data_loader import file_extractor | |||
| from core.data_loader.file_extractor import FileExtractor | |||
| from core.entities.application_entities import ModelConfigEntity | |||
| from core.rag.extractor import extract_processor | |||
| from core.rag.extractor.extract_processor import ExtractProcessor | |||
| FULL_TEMPLATE = """ | |||
| TITLE: {title} | |||
| @@ -149,7 +149,7 @@ def get_url(url: str, user_agent: str = None) -> str: | |||
| if user_agent: | |||
| headers["User-Agent"] = user_agent | |||
| supported_content_types = file_extractor.SUPPORT_URL_CONTENT_TYPES + ["text/html"] | |||
| supported_content_types = extract_processor.SUPPORT_URL_CONTENT_TYPES + ["text/html"] | |||
| head_response = requests.head(url, headers=headers, allow_redirects=True, timeout=(5, 10)) | |||
| @@ -161,8 +161,8 @@ def get_url(url: str, user_agent: str = None) -> str: | |||
| if main_content_type not in supported_content_types: | |||
| return "Unsupported content-type [{}] of URL.".format(main_content_type) | |||
| if main_content_type in file_extractor.SUPPORT_URL_CONTENT_TYPES: | |||
| return FileExtractor.load_from_url(url, return_text=True) | |||
| if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES: | |||
| return ExtractProcessor.load_from_url(url, return_text=True) | |||
| response = requests.get(url, headers=headers, allow_redirects=True, timeout=(5, 30)) | |||
| a = extract_using_readabilipy(response.text) | |||
| @@ -1,56 +0,0 @@ | |||
| from core.vector_store.vector.milvus import Milvus | |||
| class MilvusVectorStore(Milvus): | |||
| def del_texts(self, where_filter: dict): | |||
| if not where_filter: | |||
| raise ValueError('where_filter must not be empty') | |||
| self.col.delete(where_filter.get('filter')) | |||
| def del_text(self, uuid: str) -> None: | |||
| expr = f"id == {uuid}" | |||
| self.col.delete(expr) | |||
| def text_exists(self, uuid: str) -> bool: | |||
| result = self.col.query( | |||
| expr=f'metadata["doc_id"] == "{uuid}"', | |||
| output_fields=["id"] | |||
| ) | |||
| return len(result) > 0 | |||
| def get_ids_by_document_id(self, document_id: str): | |||
| result = self.col.query( | |||
| expr=f'metadata["document_id"] == "{document_id}"', | |||
| output_fields=["id"] | |||
| ) | |||
| if result: | |||
| return [item["id"] for item in result] | |||
| else: | |||
| return None | |||
| def get_ids_by_metadata_field(self, key: str, value: str): | |||
| result = self.col.query( | |||
| expr=f'metadata["{key}"] == "{value}"', | |||
| output_fields=["id"] | |||
| ) | |||
| if result: | |||
| return [item["id"] for item in result] | |||
| else: | |||
| return None | |||
| def get_ids_by_doc_ids(self, doc_ids: list): | |||
| result = self.col.query( | |||
| expr=f'metadata["doc_id"] in {doc_ids}', | |||
| output_fields=["id"] | |||
| ) | |||
| if result: | |||
| return [item["id"] for item in result] | |||
| else: | |||
| return None | |||
| def delete(self): | |||
| from pymilvus import utility | |||
| utility.drop_collection(self.collection_name, None, self.alias) | |||
| @@ -1,76 +0,0 @@ | |||
| from typing import Any, cast | |||
| from langchain.schema import Document | |||
| from qdrant_client.http.models import Filter, FilterSelector, PointIdsList | |||
| from qdrant_client.local.qdrant_local import QdrantLocal | |||
| from core.vector_store.vector.qdrant import Qdrant | |||
| class QdrantVectorStore(Qdrant): | |||
| def del_texts(self, filter: Filter): | |||
| if not filter: | |||
| raise ValueError('filter must not be empty') | |||
| self._reload_if_needed() | |||
| self.client.delete( | |||
| collection_name=self.collection_name, | |||
| points_selector=FilterSelector( | |||
| filter=filter | |||
| ), | |||
| ) | |||
| def del_text(self, uuid: str) -> None: | |||
| self._reload_if_needed() | |||
| self.client.delete( | |||
| collection_name=self.collection_name, | |||
| points_selector=PointIdsList( | |||
| points=[uuid], | |||
| ), | |||
| ) | |||
| def text_exists(self, uuid: str) -> bool: | |||
| self._reload_if_needed() | |||
| response = self.client.retrieve( | |||
| collection_name=self.collection_name, | |||
| ids=[uuid] | |||
| ) | |||
| return len(response) > 0 | |||
| def delete(self): | |||
| self._reload_if_needed() | |||
| self.client.delete_collection(collection_name=self.collection_name) | |||
| def delete_group(self): | |||
| self._reload_if_needed() | |||
| self.client.delete_collection(collection_name=self.collection_name) | |||
| @classmethod | |||
| def _document_from_scored_point( | |||
| cls, | |||
| scored_point: Any, | |||
| content_payload_key: str, | |||
| metadata_payload_key: str, | |||
| ) -> Document: | |||
| if scored_point.payload.get('doc_id'): | |||
| return Document( | |||
| page_content=scored_point.payload.get(content_payload_key), | |||
| metadata={'doc_id': scored_point.id} | |||
| ) | |||
| return Document( | |||
| page_content=scored_point.payload.get(content_payload_key), | |||
| metadata=scored_point.payload.get(metadata_payload_key) or {}, | |||
| ) | |||
| def _reload_if_needed(self): | |||
| if isinstance(self.client, QdrantLocal): | |||
| self.client = cast(QdrantLocal, self.client) | |||
| self.client._load() | |||
| @@ -1,852 +0,0 @@ | |||
| """Wrapper around the Milvus vector database.""" | |||
| from __future__ import annotations | |||
| import logging | |||
| from collections.abc import Iterable, Sequence | |||
| from typing import Any, Optional, Union | |||
| from uuid import uuid4 | |||
| import numpy as np | |||
| from langchain.docstore.document import Document | |||
| from langchain.embeddings.base import Embeddings | |||
| from langchain.vectorstores.base import VectorStore | |||
| from langchain.vectorstores.utils import maximal_marginal_relevance | |||
| logger = logging.getLogger(__name__) | |||
| DEFAULT_MILVUS_CONNECTION = { | |||
| "host": "localhost", | |||
| "port": "19530", | |||
| "user": "", | |||
| "password": "", | |||
| "secure": False, | |||
| } | |||
| class Milvus(VectorStore): | |||
| """Initialize wrapper around the milvus vector database. | |||
| In order to use this you need to have `pymilvus` installed and a | |||
| running Milvus instance. | |||
| See the following documentation for how to run a Milvus instance: | |||
| https://milvus.io/docs/install_standalone-docker.md | |||
| If looking for a hosted Milvus, take a look at this documentation: | |||
| https://zilliz.com/cloud and make use of the Zilliz vectorstore found in | |||
| this project. | |||
| IF USING L2/IP metric IT IS HIGHLY SUGGESTED TO NORMALIZE YOUR DATA. | |||
| Args: | |||
| embedding_function (Embeddings): Function used to embed the text. | |||
| collection_name (str): Which Milvus collection to use. Defaults to | |||
| "LangChainCollection". | |||
| connection_args (Optional[dict[str, any]]): The connection args used for | |||
| this class come in the form of a dict. | |||
| consistency_level (str): The consistency level to use for a collection. | |||
| Defaults to "Session". | |||
| index_params (Optional[dict]): Which index params to use. Defaults to | |||
| HNSW/AUTOINDEX depending on service. | |||
| search_params (Optional[dict]): Which search params to use. Defaults to | |||
| default of index. | |||
| drop_old (Optional[bool]): Whether to drop the current collection. Defaults | |||
| to False. | |||
| The connection args used for this class come in the form of a dict; | |||
| here are a few of the options: | |||
| address (str): The actual address of Milvus | |||
| instance. Example address: "localhost:19530" | |||
| uri (str): The uri of Milvus instance. Example uri: | |||
| "http://randomwebsite:19530", | |||
| "tcp:foobarsite:19530", | |||
| "https://ok.s3.south.com:19530". | |||
| host (str): The host of Milvus instance. Default at "localhost", | |||
| PyMilvus will fill in the default host if only port is provided. | |||
| port (str/int): The port of Milvus instance. Default at 19530, PyMilvus | |||
| will fill in the default port if only host is provided. | |||
| user (str): Use which user to connect to Milvus instance. If user and | |||
| password are provided, we will add related header in every RPC call. | |||
| password (str): Required when user is provided. The password | |||
| corresponding to the user. | |||
| secure (bool): Default is false. If set to true, tls will be enabled. | |||
| client_key_path (str): If use tls two-way authentication, need to | |||
| write the client.key path. | |||
| client_pem_path (str): If use tls two-way authentication, need to | |||
| write the client.pem path. | |||
| ca_pem_path (str): If use tls two-way authentication, need to write | |||
| the ca.pem path. | |||
| server_pem_path (str): If use tls one-way authentication, need to | |||
| write the server.pem path. | |||
| server_name (str): If use tls, need to write the common name. | |||
| Example: | |||
| .. code-block:: python | |||
| from langchain import Milvus | |||
| from langchain.embeddings import OpenAIEmbeddings | |||
| embedding = OpenAIEmbeddings() | |||
| # Connect to a milvus instance on localhost | |||
| milvus_store = Milvus( | |||
| embedding_function = Embeddings, | |||
| collection_name = "LangChainCollection", | |||
| drop_old = True, | |||
| ) | |||
| Raises: | |||
| ValueError: If the pymilvus python package is not installed. | |||
| """ | |||
| def __init__( | |||
| self, | |||
| embedding_function: Embeddings, | |||
| collection_name: str = "LangChainCollection", | |||
| connection_args: Optional[dict[str, Any]] = None, | |||
| consistency_level: str = "Session", | |||
| index_params: Optional[dict] = None, | |||
| search_params: Optional[dict] = None, | |||
| drop_old: Optional[bool] = False, | |||
| ): | |||
| """Initialize the Milvus vector store.""" | |||
| try: | |||
| from pymilvus import Collection, utility | |||
| except ImportError: | |||
| raise ValueError( | |||
| "Could not import pymilvus python package. " | |||
| "Please install it with `pip install pymilvus`." | |||
| ) | |||
| # Default search params when one is not provided. | |||
| self.default_search_params = { | |||
| "IVF_FLAT": {"metric_type": "L2", "params": {"nprobe": 10}}, | |||
| "IVF_SQ8": {"metric_type": "L2", "params": {"nprobe": 10}}, | |||
| "IVF_PQ": {"metric_type": "L2", "params": {"nprobe": 10}}, | |||
| "HNSW": {"metric_type": "L2", "params": {"ef": 10}}, | |||
| "RHNSW_FLAT": {"metric_type": "L2", "params": {"ef": 10}}, | |||
| "RHNSW_SQ": {"metric_type": "L2", "params": {"ef": 10}}, | |||
| "RHNSW_PQ": {"metric_type": "L2", "params": {"ef": 10}}, | |||
| "IVF_HNSW": {"metric_type": "L2", "params": {"nprobe": 10, "ef": 10}}, | |||
| "ANNOY": {"metric_type": "L2", "params": {"search_k": 10}}, | |||
| "AUTOINDEX": {"metric_type": "L2", "params": {}}, | |||
| } | |||
| self.embedding_func = embedding_function | |||
| self.collection_name = collection_name | |||
| self.index_params = index_params | |||
| self.search_params = search_params | |||
| self.consistency_level = consistency_level | |||
| # In order for a collection to be compatible, pk needs to be auto-id and int | |||
| self._primary_field = "id" | |||
| # In order for compatibility, the text field will need to be called "text" | |||
| self._text_field = "page_content" | |||
| # In order for compatibility, the vector field needs to be called "vector" | |||
| self._vector_field = "vectors" | |||
| # In order for compatibility, the metadata field will need to be called "metadata" | |||
| self._metadata_field = "metadata" | |||
| self.fields: list[str] = [] | |||
| # Create the connection to the server | |||
| if connection_args is None: | |||
| connection_args = DEFAULT_MILVUS_CONNECTION | |||
| self.alias = self._create_connection_alias(connection_args) | |||
| self.col: Optional[Collection] = None | |||
| # Grab the existing collection if it exists | |||
| if utility.has_collection(self.collection_name, using=self.alias): | |||
| self.col = Collection( | |||
| self.collection_name, | |||
| using=self.alias, | |||
| ) | |||
| # If need to drop old, drop it | |||
| if drop_old and isinstance(self.col, Collection): | |||
| self.col.drop() | |||
| self.col = None | |||
| # Initialize the vector store | |||
| self._init() | |||
| @property | |||
| def embeddings(self) -> Embeddings: | |||
| return self.embedding_func | |||
| def _create_connection_alias(self, connection_args: dict) -> str: | |||
| """Create the connection to the Milvus server.""" | |||
| from pymilvus import MilvusException, connections | |||
| # Grab the connection arguments that are used for checking existing connection | |||
| host: str = connection_args.get("host", None) | |||
| port: Union[str, int] = connection_args.get("port", None) | |||
| address: str = connection_args.get("address", None) | |||
| uri: str = connection_args.get("uri", None) | |||
| user = connection_args.get("user", None) | |||
| # Order of use is host/port, uri, address | |||
| if host is not None and port is not None: | |||
| given_address = str(host) + ":" + str(port) | |||
| elif uri is not None: | |||
| given_address = uri.split("https://")[1] | |||
| elif address is not None: | |||
| given_address = address | |||
| else: | |||
| given_address = None | |||
| logger.debug("Missing standard address type for reuse atttempt") | |||
| # User defaults to empty string when getting connection info | |||
| if user is not None: | |||
| tmp_user = user | |||
| else: | |||
| tmp_user = "" | |||
| # If a valid address was given, then check if a connection exists | |||
| if given_address is not None: | |||
| for con in connections.list_connections(): | |||
| addr = connections.get_connection_addr(con[0]) | |||
| if ( | |||
| con[1] | |||
| and ("address" in addr) | |||
| and (addr["address"] == given_address) | |||
| and ("user" in addr) | |||
| and (addr["user"] == tmp_user) | |||
| ): | |||
| logger.debug("Using previous connection: %s", con[0]) | |||
| return con[0] | |||
| # Generate a new connection if one doesn't exist | |||
| alias = uuid4().hex | |||
| try: | |||
| connections.connect(alias=alias, **connection_args) | |||
| logger.debug("Created new connection using: %s", alias) | |||
| return alias | |||
| except MilvusException as e: | |||
| logger.error("Failed to create new connection using: %s", alias) | |||
| raise e | |||
| def _init( | |||
| self, embeddings: Optional[list] = None, metadatas: Optional[list[dict]] = None | |||
| ) -> None: | |||
| if embeddings is not None: | |||
| self._create_collection(embeddings, metadatas) | |||
| self._extract_fields() | |||
| self._create_index() | |||
| self._create_search_params() | |||
| self._load() | |||
| def _create_collection( | |||
| self, embeddings: list, metadatas: Optional[list[dict]] = None | |||
| ) -> None: | |||
| from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, MilvusException | |||
| from pymilvus.orm.types import infer_dtype_bydata | |||
| # Determine embedding dim | |||
| dim = len(embeddings[0]) | |||
| fields = [] | |||
| # Determine metadata schema | |||
| # if metadatas: | |||
| # # Create FieldSchema for each entry in metadata. | |||
| # for key, value in metadatas[0].items(): | |||
| # # Infer the corresponding datatype of the metadata | |||
| # dtype = infer_dtype_bydata(value) | |||
| # # Datatype isn't compatible | |||
| # if dtype == DataType.UNKNOWN or dtype == DataType.NONE: | |||
| # logger.error( | |||
| # "Failure to create collection, unrecognized dtype for key: %s", | |||
| # key, | |||
| # ) | |||
| # raise ValueError(f"Unrecognized datatype for {key}.") | |||
| # # Dataype is a string/varchar equivalent | |||
| # elif dtype == DataType.VARCHAR: | |||
| # fields.append(FieldSchema(key, DataType.VARCHAR, max_length=65_535)) | |||
| # else: | |||
| # fields.append(FieldSchema(key, dtype)) | |||
| if metadatas: | |||
| fields.append(FieldSchema(self._metadata_field, DataType.JSON, max_length=65_535)) | |||
| # Create the text field | |||
| fields.append( | |||
| FieldSchema(self._text_field, DataType.VARCHAR, max_length=65_535) | |||
| ) | |||
| # Create the primary key field | |||
| fields.append( | |||
| FieldSchema( | |||
| self._primary_field, DataType.INT64, is_primary=True, auto_id=True | |||
| ) | |||
| ) | |||
| # Create the vector field, supports binary or float vectors | |||
| fields.append( | |||
| FieldSchema(self._vector_field, infer_dtype_bydata(embeddings[0]), dim=dim) | |||
| ) | |||
| # Create the schema for the collection | |||
| schema = CollectionSchema(fields) | |||
| # Create the collection | |||
| try: | |||
| self.col = Collection( | |||
| name=self.collection_name, | |||
| schema=schema, | |||
| consistency_level=self.consistency_level, | |||
| using=self.alias, | |||
| ) | |||
| except MilvusException as e: | |||
| logger.error( | |||
| "Failed to create collection: %s error: %s", self.collection_name, e | |||
| ) | |||
| raise e | |||
| def _extract_fields(self) -> None: | |||
| """Grab the existing fields from the Collection""" | |||
| from pymilvus import Collection | |||
| if isinstance(self.col, Collection): | |||
| schema = self.col.schema | |||
| for x in schema.fields: | |||
| self.fields.append(x.name) | |||
| # Since primary field is auto-id, no need to track it | |||
| self.fields.remove(self._primary_field) | |||
| def _get_index(self) -> Optional[dict[str, Any]]: | |||
| """Return the vector index information if it exists""" | |||
| from pymilvus import Collection | |||
| if isinstance(self.col, Collection): | |||
| for x in self.col.indexes: | |||
| if x.field_name == self._vector_field: | |||
| return x.to_dict() | |||
| return None | |||
| def _create_index(self) -> None: | |||
| """Create a index on the collection""" | |||
| from pymilvus import Collection, MilvusException | |||
| if isinstance(self.col, Collection) and self._get_index() is None: | |||
| try: | |||
| # If no index params, use a default HNSW based one | |||
| if self.index_params is None: | |||
| self.index_params = { | |||
| "metric_type": "IP", | |||
| "index_type": "HNSW", | |||
| "params": {"M": 8, "efConstruction": 64}, | |||
| } | |||
| try: | |||
| self.col.create_index( | |||
| self._vector_field, | |||
| index_params=self.index_params, | |||
| using=self.alias, | |||
| ) | |||
| # If default did not work, most likely on Zilliz Cloud | |||
| except MilvusException: | |||
| # Use AUTOINDEX based index | |||
| self.index_params = { | |||
| "metric_type": "L2", | |||
| "index_type": "AUTOINDEX", | |||
| "params": {}, | |||
| } | |||
| self.col.create_index( | |||
| self._vector_field, | |||
| index_params=self.index_params, | |||
| using=self.alias, | |||
| ) | |||
| logger.debug( | |||
| "Successfully created an index on collection: %s", | |||
| self.collection_name, | |||
| ) | |||
| except MilvusException as e: | |||
| logger.error( | |||
| "Failed to create an index on collection: %s", self.collection_name | |||
| ) | |||
| raise e | |||
| def _create_search_params(self) -> None: | |||
| """Generate search params based on the current index type""" | |||
| from pymilvus import Collection | |||
| if isinstance(self.col, Collection) and self.search_params is None: | |||
| index = self._get_index() | |||
| if index is not None: | |||
| index_type: str = index["index_param"]["index_type"] | |||
| metric_type: str = index["index_param"]["metric_type"] | |||
| self.search_params = self.default_search_params[index_type] | |||
| self.search_params["metric_type"] = metric_type | |||
| def _load(self) -> None: | |||
| """Load the collection if available.""" | |||
| from pymilvus import Collection | |||
| if isinstance(self.col, Collection) and self._get_index() is not None: | |||
| self.col.load() | |||
| def add_texts( | |||
| self, | |||
| texts: Iterable[str], | |||
| metadatas: Optional[list[dict]] = None, | |||
| timeout: Optional[int] = None, | |||
| batch_size: int = 1000, | |||
| **kwargs: Any, | |||
| ) -> list[str]: | |||
| """Insert text data into Milvus. | |||
| Inserting data when the collection has not been made yet will result | |||
| in creating a new Collection. The data of the first entity decides | |||
| the schema of the new collection, the dim is extracted from the first | |||
| embedding and the columns are decided by the first metadata dict. | |||
| Metadata keys will need to be present for all inserted values. At | |||
| the moment there is no None equivalent in Milvus. | |||
| Args: | |||
| texts (Iterable[str]): The texts to embed, it is assumed | |||
| that they all fit in memory. | |||
| metadatas (Optional[List[dict]]): Metadata dicts attached to each of | |||
| the texts. Defaults to None. | |||
| timeout (Optional[int]): Timeout for each batch insert. Defaults | |||
| to None. | |||
| batch_size (int, optional): Batch size to use for insertion. | |||
| Defaults to 1000. | |||
| Raises: | |||
| MilvusException: Failure to add texts | |||
| Returns: | |||
| List[str]: The resulting keys for each inserted element. | |||
| """ | |||
| from pymilvus import Collection, MilvusException | |||
| texts = list(texts) | |||
| try: | |||
| embeddings = self.embedding_func.embed_documents(texts) | |||
| except NotImplementedError: | |||
| embeddings = [self.embedding_func.embed_query(x) for x in texts] | |||
| if len(embeddings) == 0: | |||
| logger.debug("Nothing to insert, skipping.") | |||
| return [] | |||
| # If the collection hasn't been initialized yet, perform all steps to do so | |||
| if not isinstance(self.col, Collection): | |||
| self._init(embeddings, metadatas) | |||
| # Dict to hold all insert columns | |||
| insert_dict: dict[str, list] = { | |||
| self._text_field: texts, | |||
| self._vector_field: embeddings, | |||
| } | |||
| # Collect the metadata into the insert dict. | |||
| # if metadatas is not None: | |||
| # for d in metadatas: | |||
| # for key, value in d.items(): | |||
| # if key in self.fields: | |||
| # insert_dict.setdefault(key, []).append(value) | |||
| if metadatas is not None: | |||
| for d in metadatas: | |||
| insert_dict.setdefault(self._metadata_field, []).append(d) | |||
| # Total insert count | |||
| vectors: list = insert_dict[self._vector_field] | |||
| total_count = len(vectors) | |||
| pks: list[str] = [] | |||
| assert isinstance(self.col, Collection) | |||
| for i in range(0, total_count, batch_size): | |||
| # Grab end index | |||
| end = min(i + batch_size, total_count) | |||
| # Convert dict to list of lists batch for insertion | |||
| insert_list = [insert_dict[x][i:end] for x in self.fields] | |||
| # Insert into the collection. | |||
| try: | |||
| res: Collection | |||
| res = self.col.insert(insert_list, timeout=timeout, **kwargs) | |||
| pks.extend(res.primary_keys) | |||
| except MilvusException as e: | |||
| logger.error( | |||
| "Failed to insert batch starting at entity: %s/%s", i, total_count | |||
| ) | |||
| raise e | |||
| return pks | |||
| def similarity_search( | |||
| self, | |||
| query: str, | |||
| k: int = 4, | |||
| param: Optional[dict] = None, | |||
| expr: Optional[str] = None, | |||
| timeout: Optional[int] = None, | |||
| **kwargs: Any, | |||
| ) -> list[Document]: | |||
| """Perform a similarity search against the query string. | |||
| Args: | |||
| query (str): The text to search. | |||
| k (int, optional): How many results to return. Defaults to 4. | |||
| param (dict, optional): The search params for the index type. | |||
| Defaults to None. | |||
| expr (str, optional): Filtering expression. Defaults to None. | |||
| timeout (int, optional): How long to wait before timeout error. | |||
| Defaults to None. | |||
| kwargs: Collection.search() keyword arguments. | |||
| Returns: | |||
| List[Document]: Document results for search. | |||
| """ | |||
| if self.col is None: | |||
| logger.debug("No existing collection to search.") | |||
| return [] | |||
| res = self.similarity_search_with_score( | |||
| query=query, k=k, param=param, expr=expr, timeout=timeout, **kwargs | |||
| ) | |||
| return [doc for doc, _ in res] | |||
| def similarity_search_by_vector( | |||
| self, | |||
| embedding: list[float], | |||
| k: int = 4, | |||
| param: Optional[dict] = None, | |||
| expr: Optional[str] = None, | |||
| timeout: Optional[int] = None, | |||
| **kwargs: Any, | |||
| ) -> list[Document]: | |||
| """Perform a similarity search against the query string. | |||
| Args: | |||
| embedding (List[float]): The embedding vector to search. | |||
| k (int, optional): How many results to return. Defaults to 4. | |||
| param (dict, optional): The search params for the index type. | |||
| Defaults to None. | |||
| expr (str, optional): Filtering expression. Defaults to None. | |||
| timeout (int, optional): How long to wait before timeout error. | |||
| Defaults to None. | |||
| kwargs: Collection.search() keyword arguments. | |||
| Returns: | |||
| List[Document]: Document results for search. | |||
| """ | |||
| if self.col is None: | |||
| logger.debug("No existing collection to search.") | |||
| return [] | |||
| res = self.similarity_search_with_score_by_vector( | |||
| embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs | |||
| ) | |||
| return [doc for doc, _ in res] | |||
| def similarity_search_with_score( | |||
| self, | |||
| query: str, | |||
| k: int = 4, | |||
| param: Optional[dict] = None, | |||
| expr: Optional[str] = None, | |||
| timeout: Optional[int] = None, | |||
| **kwargs: Any, | |||
| ) -> list[tuple[Document, float]]: | |||
| """Perform a search on a query string and return results with score. | |||
| For more information about the search parameters, take a look at the pymilvus | |||
| documentation found here: | |||
| https://milvus.io/api-reference/pymilvus/v2.2.6/Collection/search().md | |||
| Args: | |||
| query (str): The text being searched. | |||
| k (int, optional): The amount of results to return. Defaults to 4. | |||
| param (dict): The search params for the specified index. | |||
| Defaults to None. | |||
| expr (str, optional): Filtering expression. Defaults to None. | |||
| timeout (int, optional): How long to wait before timeout error. | |||
| Defaults to None. | |||
| kwargs: Collection.search() keyword arguments. | |||
| Returns: | |||
| List[Tuple[Document, float]]: Result doc and score. | |||
| """ | |||
| if self.col is None: | |||
| logger.debug("No existing collection to search.") | |||
| return [] | |||
| # Embed the query text. | |||
| embedding = self.embedding_func.embed_query(query) | |||
| res = self.similarity_search_with_score_by_vector( | |||
| embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs | |||
| ) | |||
| return res | |||
| def _similarity_search_with_relevance_scores( | |||
| self, | |||
| query: str, | |||
| k: int = 4, | |||
| **kwargs: Any, | |||
| ) -> list[tuple[Document, float]]: | |||
| """Return docs and relevance scores in the range [0, 1]. | |||
| 0 is dissimilar, 1 is most similar. | |||
| Args: | |||
| query: input text | |||
| k: Number of Documents to return. Defaults to 4. | |||
| **kwargs: kwargs to be passed to similarity search. Should include: | |||
| score_threshold: Optional, a floating point value between 0 and 1 to | |||
| filter the resulting set of retrieved docs | |||
| Returns: | |||
| List of Tuples of (doc, similarity_score) | |||
| """ | |||
| return self.similarity_search_with_score(query, k, **kwargs) | |||
| def similarity_search_with_score_by_vector( | |||
| self, | |||
| embedding: list[float], | |||
| k: int = 4, | |||
| param: Optional[dict] = None, | |||
| expr: Optional[str] = None, | |||
| timeout: Optional[int] = None, | |||
| **kwargs: Any, | |||
| ) -> list[tuple[Document, float]]: | |||
| """Perform a search on a query string and return results with score. | |||
| For more information about the search parameters, take a look at the pymilvus | |||
| documentation found here: | |||
| https://milvus.io/api-reference/pymilvus/v2.2.6/Collection/search().md | |||
| Args: | |||
| embedding (List[float]): The embedding vector being searched. | |||
| k (int, optional): The amount of results to return. Defaults to 4. | |||
| param (dict): The search params for the specified index. | |||
| Defaults to None. | |||
| expr (str, optional): Filtering expression. Defaults to None. | |||
| timeout (int, optional): How long to wait before timeout error. | |||
| Defaults to None. | |||
| kwargs: Collection.search() keyword arguments. | |||
| Returns: | |||
| List[Tuple[Document, float]]: Result doc and score. | |||
| """ | |||
| if self.col is None: | |||
| logger.debug("No existing collection to search.") | |||
| return [] | |||
| if param is None: | |||
| param = self.search_params | |||
| # Determine result metadata fields. | |||
| output_fields = self.fields[:] | |||
| output_fields.remove(self._vector_field) | |||
| # Perform the search. | |||
| res = self.col.search( | |||
| data=[embedding], | |||
| anns_field=self._vector_field, | |||
| param=param, | |||
| limit=k, | |||
| expr=expr, | |||
| output_fields=output_fields, | |||
| timeout=timeout, | |||
| **kwargs, | |||
| ) | |||
| # Organize results. | |||
| ret = [] | |||
| for result in res[0]: | |||
| meta = {x: result.entity.get(x) for x in output_fields} | |||
| doc = Document(page_content=meta.pop(self._text_field), metadata=meta.get('metadata')) | |||
| pair = (doc, result.score) | |||
| ret.append(pair) | |||
| return ret | |||
| def max_marginal_relevance_search( | |||
| self, | |||
| query: str, | |||
| k: int = 4, | |||
| fetch_k: int = 20, | |||
| lambda_mult: float = 0.5, | |||
| param: Optional[dict] = None, | |||
| expr: Optional[str] = None, | |||
| timeout: Optional[int] = None, | |||
| **kwargs: Any, | |||
| ) -> list[Document]: | |||
| """Perform a search and return results that are reordered by MMR. | |||
| Args: | |||
| query (str): The text being searched. | |||
| k (int, optional): How many results to give. Defaults to 4. | |||
| fetch_k (int, optional): Total results to select k from. | |||
| Defaults to 20. | |||
| lambda_mult: Number between 0 and 1 that determines the degree | |||
| of diversity among the results with 0 corresponding | |||
| to maximum diversity and 1 to minimum diversity. | |||
| Defaults to 0.5 | |||
| param (dict, optional): The search params for the specified index. | |||
| Defaults to None. | |||
| expr (str, optional): Filtering expression. Defaults to None. | |||
| timeout (int, optional): How long to wait before timeout error. | |||
| Defaults to None. | |||
| kwargs: Collection.search() keyword arguments. | |||
| Returns: | |||
| List[Document]: Document results for search. | |||
| """ | |||
| if self.col is None: | |||
| logger.debug("No existing collection to search.") | |||
| return [] | |||
| embedding = self.embedding_func.embed_query(query) | |||
| return self.max_marginal_relevance_search_by_vector( | |||
| embedding=embedding, | |||
| k=k, | |||
| fetch_k=fetch_k, | |||
| lambda_mult=lambda_mult, | |||
| param=param, | |||
| expr=expr, | |||
| timeout=timeout, | |||
| **kwargs, | |||
| ) | |||
| def max_marginal_relevance_search_by_vector( | |||
| self, | |||
| embedding: list[float], | |||
| k: int = 4, | |||
| fetch_k: int = 20, | |||
| lambda_mult: float = 0.5, | |||
| param: Optional[dict] = None, | |||
| expr: Optional[str] = None, | |||
| timeout: Optional[int] = None, | |||
| **kwargs: Any, | |||
| ) -> list[Document]: | |||
| """Perform a search and return results that are reordered by MMR. | |||
| Args: | |||
| embedding (List[float]): The embedding vector being searched. | |||
| k (int, optional): How many results to give. Defaults to 4. | |||
| fetch_k (int, optional): Total results to select k from. | |||
| Defaults to 20. | |||
| lambda_mult: Number between 0 and 1 that determines the degree | |||
| of diversity among the results with 0 corresponding | |||
| to maximum diversity and 1 to minimum diversity. | |||
| Defaults to 0.5 | |||
| param (dict, optional): The search params for the specified index. | |||
| Defaults to None. | |||
| expr (str, optional): Filtering expression. Defaults to None. | |||
| timeout (int, optional): How long to wait before timeout error. | |||
| Defaults to None. | |||
| kwargs: Collection.search() keyword arguments. | |||
| Returns: | |||
| List[Document]: Document results for search. | |||
| """ | |||
| if self.col is None: | |||
| logger.debug("No existing collection to search.") | |||
| return [] | |||
| if param is None: | |||
| param = self.search_params | |||
| # Determine result metadata fields. | |||
| output_fields = self.fields[:] | |||
| output_fields.remove(self._vector_field) | |||
| # Perform the search. | |||
| res = self.col.search( | |||
| data=[embedding], | |||
| anns_field=self._vector_field, | |||
| param=param, | |||
| limit=fetch_k, | |||
| expr=expr, | |||
| output_fields=output_fields, | |||
| timeout=timeout, | |||
| **kwargs, | |||
| ) | |||
| # Organize results. | |||
| ids = [] | |||
| documents = [] | |||
| scores = [] | |||
| for result in res[0]: | |||
| meta = {x: result.entity.get(x) for x in output_fields} | |||
| doc = Document(page_content=meta.pop(self._text_field), metadata=meta) | |||
| documents.append(doc) | |||
| scores.append(result.score) | |||
| ids.append(result.id) | |||
| vectors = self.col.query( | |||
| expr=f"{self._primary_field} in {ids}", | |||
| output_fields=[self._primary_field, self._vector_field], | |||
| timeout=timeout, | |||
| ) | |||
| # Reorganize the results from query to match search order. | |||
| vectors = {x[self._primary_field]: x[self._vector_field] for x in vectors} | |||
| ordered_result_embeddings = [vectors[x] for x in ids] | |||
| # Get the new order of results. | |||
| new_ordering = maximal_marginal_relevance( | |||
| np.array(embedding), ordered_result_embeddings, k=k, lambda_mult=lambda_mult | |||
| ) | |||
| # Reorder the values and return. | |||
| ret = [] | |||
| for x in new_ordering: | |||
| # Function can return -1 index | |||
| if x == -1: | |||
| break | |||
| else: | |||
| ret.append(documents[x]) | |||
| return ret | |||
| @classmethod | |||
| def from_texts( | |||
| cls, | |||
| texts: list[str], | |||
| embedding: Embeddings, | |||
| metadatas: Optional[list[dict]] = None, | |||
| collection_name: str = "LangChainCollection", | |||
| connection_args: dict[str, Any] = DEFAULT_MILVUS_CONNECTION, | |||
| consistency_level: str = "Session", | |||
| index_params: Optional[dict] = None, | |||
| search_params: Optional[dict] = None, | |||
| drop_old: bool = False, | |||
| batch_size: int = 100, | |||
| ids: Optional[Sequence[str]] = None, | |||
| **kwargs: Any, | |||
| ) -> Milvus: | |||
| """Create a Milvus collection, indexes it with HNSW, and insert data. | |||
| Args: | |||
| texts (List[str]): Text data. | |||
| embedding (Embeddings): Embedding function. | |||
| metadatas (Optional[List[dict]]): Metadata for each text if it exists. | |||
| Defaults to None. | |||
| collection_name (str, optional): Collection name to use. Defaults to | |||
| "LangChainCollection". | |||
| connection_args (dict[str, Any], optional): Connection args to use. Defaults | |||
| to DEFAULT_MILVUS_CONNECTION. | |||
| consistency_level (str, optional): Which consistency level to use. Defaults | |||
| to "Session". | |||
| index_params (Optional[dict], optional): Which index_params to use. Defaults | |||
| to None. | |||
| search_params (Optional[dict], optional): Which search params to use. | |||
| Defaults to None. | |||
| drop_old (Optional[bool], optional): Whether to drop the collection with | |||
| that name if it exists. Defaults to False. | |||
| batch_size (int, optional): How many vectors to upload per request. | |||
| Defaults to 100. | |||
| ids (Optional[Sequence[str]]): Optional ids for the texts. | |||
| Defaults to None. | |||
| Returns: | |||
| Milvus: Milvus Vector Store | |||
| """ | |||
| vector_db = cls( | |||
| embedding_function=embedding, | |||
| collection_name=collection_name, | |||
| connection_args=connection_args, | |||
| consistency_level=consistency_level, | |||
| index_params=index_params, | |||
| search_params=search_params, | |||
| drop_old=drop_old, | |||
| **kwargs, | |||
| ) | |||
| vector_db.add_texts(texts=texts, metadatas=metadatas, batch_size=batch_size) | |||
| return vector_db | |||
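| A minimal usage sketch (not part of the diff) for the Milvus wrapper above; the embedding object, collection name and connection details are illustrative assumptions rather than values from this PR: | |||
| from langchain.embeddings.base import Embeddings | |||
| def build_milvus_store(texts: list[str], embedding: Embeddings) -> Milvus: | |||
|     # from_texts creates the collection (schema inferred from the first entity), | |||
|     # builds the index and inserts the embedded texts in batches. | |||
|     return Milvus.from_texts( | |||
|         texts=texts, | |||
|         embedding=embedding, | |||
|         metadatas=[{"doc_id": str(i)} for i, _ in enumerate(texts)], | |||
|         collection_name="ExampleCollection", | |||
|         connection_args={"host": "127.0.0.1", "port": "19530"}, | |||
|     ) | |||
| # store = build_milvus_store(["first doc", "second doc"], my_embeddings) | |||
| # hits = store.similarity_search("first doc", k=2) | |||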
| @@ -1,506 +0,0 @@ | |||
| """Wrapper around weaviate vector database.""" | |||
| from __future__ import annotations | |||
| import datetime | |||
| from collections.abc import Callable, Iterable | |||
| from typing import Any, Optional | |||
| from uuid import uuid4 | |||
| import numpy as np | |||
| from langchain.docstore.document import Document | |||
| from langchain.embeddings.base import Embeddings | |||
| from langchain.utils import get_from_dict_or_env | |||
| from langchain.vectorstores.base import VectorStore | |||
| from langchain.vectorstores.utils import maximal_marginal_relevance | |||
| def _default_schema(index_name: str) -> dict: | |||
| return { | |||
| "class": index_name, | |||
| "properties": [ | |||
| { | |||
| "name": "text", | |||
| "dataType": ["text"], | |||
| } | |||
| ], | |||
| } | |||
| def _create_weaviate_client(**kwargs: Any) -> Any: | |||
| client = kwargs.get("client") | |||
| if client is not None: | |||
| return client | |||
| weaviate_url = get_from_dict_or_env(kwargs, "weaviate_url", "WEAVIATE_URL") | |||
| try: | |||
| # the weaviate api key param should not be mandatory | |||
| weaviate_api_key = get_from_dict_or_env( | |||
| kwargs, "weaviate_api_key", "WEAVIATE_API_KEY", None | |||
| ) | |||
| except ValueError: | |||
| weaviate_api_key = None | |||
| try: | |||
| import weaviate | |||
| except ImportError: | |||
| raise ValueError( | |||
| "Could not import weaviate python package. " | |||
| "Please install it with `pip install weaviate-client`" | |||
| ) | |||
| auth = ( | |||
| weaviate.auth.AuthApiKey(api_key=weaviate_api_key) | |||
| if weaviate_api_key is not None | |||
| else None | |||
| ) | |||
| client = weaviate.Client(weaviate_url, auth_client_secret=auth) | |||
| return client | |||
| def _default_score_normalizer(val: float) -> float: | |||
| return 1 - val | |||
| def _json_serializable(value: Any) -> Any: | |||
| if isinstance(value, datetime.datetime): | |||
| return value.isoformat() | |||
| return value | |||
| class Weaviate(VectorStore): | |||
| """Wrapper around Weaviate vector database. | |||
| To use, you should have the ``weaviate-client`` python package installed. | |||
| Example: | |||
| .. code-block:: python | |||
| import weaviate | |||
| from langchain.vectorstores import Weaviate | |||
| client = weaviate.Client(url=os.environ["WEAVIATE_URL"], ...) | |||
| weaviate = Weaviate(client, index_name, text_key) | |||
| """ | |||
| def __init__( | |||
| self, | |||
| client: Any, | |||
| index_name: str, | |||
| text_key: str, | |||
| embedding: Optional[Embeddings] = None, | |||
| attributes: Optional[list[str]] = None, | |||
| relevance_score_fn: Optional[ | |||
| Callable[[float], float] | |||
| ] = _default_score_normalizer, | |||
| by_text: bool = True, | |||
| ): | |||
| """Initialize with Weaviate client.""" | |||
| try: | |||
| import weaviate | |||
| except ImportError: | |||
| raise ValueError( | |||
| "Could not import weaviate python package. " | |||
| "Please install it with `pip install weaviate-client`." | |||
| ) | |||
| if not isinstance(client, weaviate.Client): | |||
| raise ValueError( | |||
| f"client should be an instance of weaviate.Client, got {type(client)}" | |||
| ) | |||
| self._client = client | |||
| self._index_name = index_name | |||
| self._embedding = embedding | |||
| self._text_key = text_key | |||
| self._query_attrs = [self._text_key] | |||
| self.relevance_score_fn = relevance_score_fn | |||
| self._by_text = by_text | |||
| if attributes is not None: | |||
| self._query_attrs.extend(attributes) | |||
| @property | |||
| def embeddings(self) -> Optional[Embeddings]: | |||
| return self._embedding | |||
| def _select_relevance_score_fn(self) -> Callable[[float], float]: | |||
| return ( | |||
| self.relevance_score_fn | |||
| if self.relevance_score_fn | |||
| else _default_score_normalizer | |||
| ) | |||
| def add_texts( | |||
| self, | |||
| texts: Iterable[str], | |||
| metadatas: Optional[list[dict]] = None, | |||
| **kwargs: Any, | |||
| ) -> list[str]: | |||
| """Upload texts with metadata (properties) to Weaviate.""" | |||
| from weaviate.util import get_valid_uuid | |||
| ids = [] | |||
| embeddings: Optional[list[list[float]]] = None | |||
| if self._embedding: | |||
| if not isinstance(texts, list): | |||
| texts = list(texts) | |||
| embeddings = self._embedding.embed_documents(texts) | |||
| with self._client.batch as batch: | |||
| for i, text in enumerate(texts): | |||
| data_properties = {self._text_key: text} | |||
| if metadatas is not None: | |||
| for key, val in metadatas[i].items(): | |||
| data_properties[key] = _json_serializable(val) | |||
| # Allow for ids (consistent w/ other methods) | |||
| # # Or uuids (backwards compatible w/ existing arg) | |||
| # If the UUID of one of the objects already exists | |||
| # then the existing object will be replaced by the new object. | |||
| _id = get_valid_uuid(uuid4()) | |||
| if "uuids" in kwargs: | |||
| _id = kwargs["uuids"][i] | |||
| elif "ids" in kwargs: | |||
| _id = kwargs["ids"][i] | |||
| batch.add_data_object( | |||
| data_object=data_properties, | |||
| class_name=self._index_name, | |||
| uuid=_id, | |||
| vector=embeddings[i] if embeddings else None, | |||
| ) | |||
| ids.append(_id) | |||
| return ids | |||
| def similarity_search( | |||
| self, query: str, k: int = 4, **kwargs: Any | |||
| ) -> list[Document]: | |||
| """Return docs most similar to query. | |||
| Args: | |||
| query: Text to look up documents similar to. | |||
| k: Number of Documents to return. Defaults to 4. | |||
| Returns: | |||
| List of Documents most similar to the query. | |||
| """ | |||
| if self._by_text: | |||
| return self.similarity_search_by_text(query, k, **kwargs) | |||
| else: | |||
| if self._embedding is None: | |||
| raise ValueError( | |||
| "_embedding cannot be None for similarity_search when " | |||
| "_by_text=False" | |||
| ) | |||
| embedding = self._embedding.embed_query(query) | |||
| return self.similarity_search_by_vector(embedding, k, **kwargs) | |||
| def similarity_search_by_text( | |||
| self, query: str, k: int = 4, **kwargs: Any | |||
| ) -> list[Document]: | |||
| """Return docs most similar to query. | |||
| Args: | |||
| query: Text to look up documents similar to. | |||
| k: Number of Documents to return. Defaults to 4. | |||
| Returns: | |||
| List of Documents most similar to the query. | |||
| """ | |||
| content: dict[str, Any] = {"concepts": [query]} | |||
| if kwargs.get("search_distance"): | |||
| content["certainty"] = kwargs.get("search_distance") | |||
| query_obj = self._client.query.get(self._index_name, self._query_attrs) | |||
| if kwargs.get("where_filter"): | |||
| query_obj = query_obj.with_where(kwargs.get("where_filter")) | |||
| if kwargs.get("additional"): | |||
| query_obj = query_obj.with_additional(kwargs.get("additional")) | |||
| result = query_obj.with_near_text(content).with_limit(k).do() | |||
| if "errors" in result: | |||
| raise ValueError(f"Error during query: {result['errors']}") | |||
| docs = [] | |||
| for res in result["data"]["Get"][self._index_name]: | |||
| text = res.pop(self._text_key) | |||
| docs.append(Document(page_content=text, metadata=res)) | |||
| return docs | |||
| def similarity_search_by_bm25( | |||
| self, query: str, k: int = 4, **kwargs: Any | |||
| ) -> list[Document]: | |||
| """Return docs using BM25F. | |||
| Args: | |||
| query: Text to look up documents similar to. | |||
| k: Number of Documents to return. Defaults to 4. | |||
| Returns: | |||
| List of Documents most similar to the query. | |||
| """ | |||
| content: dict[str, Any] = {"concepts": [query]} | |||
| if kwargs.get("search_distance"): | |||
| content["certainty"] = kwargs.get("search_distance") | |||
| query_obj = self._client.query.get(self._index_name, self._query_attrs) | |||
| if kwargs.get("where_filter"): | |||
| query_obj = query_obj.with_where(kwargs.get("where_filter")) | |||
| if kwargs.get("additional"): | |||
| query_obj = query_obj.with_additional(kwargs.get("additional")) | |||
| properties = ['text'] | |||
| result = query_obj.with_bm25(query=query, properties=properties).with_limit(k).do() | |||
| if "errors" in result: | |||
| raise ValueError(f"Error during query: {result['errors']}") | |||
| docs = [] | |||
| for res in result["data"]["Get"][self._index_name]: | |||
| text = res.pop(self._text_key) | |||
| docs.append(Document(page_content=text, metadata=res)) | |||
| return docs | |||
| def similarity_search_by_vector( | |||
| self, embedding: list[float], k: int = 4, **kwargs: Any | |||
| ) -> list[Document]: | |||
| """Look up similar documents by embedding vector in Weaviate.""" | |||
| vector = {"vector": embedding} | |||
| query_obj = self._client.query.get(self._index_name, self._query_attrs) | |||
| if kwargs.get("where_filter"): | |||
| query_obj = query_obj.with_where(kwargs.get("where_filter")) | |||
| if kwargs.get("additional"): | |||
| query_obj = query_obj.with_additional(kwargs.get("additional")) | |||
| result = query_obj.with_near_vector(vector).with_limit(k).do() | |||
| if "errors" in result: | |||
| raise ValueError(f"Error during query: {result['errors']}") | |||
| docs = [] | |||
| for res in result["data"]["Get"][self._index_name]: | |||
| text = res.pop(self._text_key) | |||
| docs.append(Document(page_content=text, metadata=res)) | |||
| return docs | |||
| def max_marginal_relevance_search( | |||
| self, | |||
| query: str, | |||
| k: int = 4, | |||
| fetch_k: int = 20, | |||
| lambda_mult: float = 0.5, | |||
| **kwargs: Any, | |||
| ) -> list[Document]: | |||
| """Return docs selected using the maximal marginal relevance. | |||
| Maximal marginal relevance optimizes for similarity to query AND diversity | |||
| among selected documents. | |||
| Args: | |||
| query: Text to look up documents similar to. | |||
| k: Number of Documents to return. Defaults to 4. | |||
| fetch_k: Number of Documents to fetch to pass to MMR algorithm. | |||
| lambda_mult: Number between 0 and 1 that determines the degree | |||
| of diversity among the results with 0 corresponding | |||
| to maximum diversity and 1 to minimum diversity. | |||
| Defaults to 0.5. | |||
| Returns: | |||
| List of Documents selected by maximal marginal relevance. | |||
| """ | |||
| if self._embedding is not None: | |||
| embedding = self._embedding.embed_query(query) | |||
| else: | |||
| raise ValueError( | |||
| "max_marginal_relevance_search requires a suitable Embeddings object" | |||
| ) | |||
| return self.max_marginal_relevance_search_by_vector( | |||
| embedding, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, **kwargs | |||
| ) | |||
| def max_marginal_relevance_search_by_vector( | |||
| self, | |||
| embedding: list[float], | |||
| k: int = 4, | |||
| fetch_k: int = 20, | |||
| lambda_mult: float = 0.5, | |||
| **kwargs: Any, | |||
| ) -> list[Document]: | |||
| """Return docs selected using the maximal marginal relevance. | |||
| Maximal marginal relevance optimizes for similarity to query AND diversity | |||
| among selected documents. | |||
| Args: | |||
| embedding: Embedding to look up documents similar to. | |||
| k: Number of Documents to return. Defaults to 4. | |||
| fetch_k: Number of Documents to fetch to pass to MMR algorithm. | |||
| lambda_mult: Number between 0 and 1 that determines the degree | |||
| of diversity among the results with 0 corresponding | |||
| to maximum diversity and 1 to minimum diversity. | |||
| Defaults to 0.5. | |||
| Returns: | |||
| List of Documents selected by maximal marginal relevance. | |||
| """ | |||
| vector = {"vector": embedding} | |||
| query_obj = self._client.query.get(self._index_name, self._query_attrs) | |||
| if kwargs.get("where_filter"): | |||
| query_obj = query_obj.with_where(kwargs.get("where_filter")) | |||
| results = ( | |||
| query_obj.with_additional("vector") | |||
| .with_near_vector(vector) | |||
| .with_limit(fetch_k) | |||
| .do() | |||
| ) | |||
| payload = results["data"]["Get"][self._index_name] | |||
| embeddings = [result["_additional"]["vector"] for result in payload] | |||
| mmr_selected = maximal_marginal_relevance( | |||
| np.array(embedding), embeddings, k=k, lambda_mult=lambda_mult | |||
| ) | |||
| docs = [] | |||
| for idx in mmr_selected: | |||
| text = payload[idx].pop(self._text_key) | |||
| payload[idx].pop("_additional") | |||
| meta = payload[idx] | |||
| docs.append(Document(page_content=text, metadata=meta)) | |||
| return docs | |||
| def similarity_search_with_score( | |||
| self, query: str, k: int = 4, **kwargs: Any | |||
| ) -> list[tuple[Document, float]]: | |||
| """ | |||
| Return list of documents most similar to the query | |||
| text and cosine distance in float for each. | |||
| Lower score represents more similarity. | |||
| """ | |||
| if self._embedding is None: | |||
| raise ValueError( | |||
| "_embedding cannot be None for similarity_search_with_score" | |||
| ) | |||
| content: dict[str, Any] = {"concepts": [query]} | |||
| if kwargs.get("search_distance"): | |||
| content["certainty"] = kwargs.get("search_distance") | |||
| query_obj = self._client.query.get(self._index_name, self._query_attrs) | |||
| embedded_query = self._embedding.embed_query(query) | |||
| if not self._by_text: | |||
| vector = {"vector": embedded_query} | |||
| result = ( | |||
| query_obj.with_near_vector(vector) | |||
| .with_limit(k) | |||
| .with_additional(["vector", "distance"]) | |||
| .do() | |||
| ) | |||
| else: | |||
| result = ( | |||
| query_obj.with_near_text(content) | |||
| .with_limit(k) | |||
| .with_additional(["vector", "distance"]) | |||
| .do() | |||
| ) | |||
| if "errors" in result: | |||
| raise ValueError(f"Error during query: {result['errors']}") | |||
| docs_and_scores = [] | |||
| for res in result["data"]["Get"][self._index_name]: | |||
| text = res.pop(self._text_key) | |||
| score = res["_additional"]["distance"] | |||
| docs_and_scores.append((Document(page_content=text, metadata=res), score)) | |||
| return docs_and_scores | |||
| @classmethod | |||
| def from_texts( | |||
| cls: type[Weaviate], | |||
| texts: list[str], | |||
| embedding: Embeddings, | |||
| metadatas: Optional[list[dict]] = None, | |||
| **kwargs: Any, | |||
| ) -> Weaviate: | |||
| """Construct Weaviate wrapper from raw documents. | |||
| This is a user-friendly interface that: | |||
| 1. Embeds documents. | |||
| 2. Creates a new index for the embeddings in the Weaviate instance. | |||
| 3. Adds the documents to the newly created Weaviate index. | |||
| This is intended to be a quick way to get started. | |||
| Example: | |||
| .. code-block:: python | |||
| from langchain.vectorstores.weaviate import Weaviate | |||
| from langchain.embeddings import OpenAIEmbeddings | |||
| embeddings = OpenAIEmbeddings() | |||
| weaviate = Weaviate.from_texts( | |||
| texts, | |||
| embeddings, | |||
| weaviate_url="http://localhost:8080" | |||
| ) | |||
| """ | |||
| client = _create_weaviate_client(**kwargs) | |||
| from weaviate.util import get_valid_uuid | |||
| index_name = kwargs.get("index_name", f"LangChain_{uuid4().hex}") | |||
| embeddings = embedding.embed_documents(texts) if embedding else None | |||
| text_key = "text" | |||
| schema = _default_schema(index_name) | |||
| attributes = list(metadatas[0].keys()) if metadatas else None | |||
| # check whether the index already exists | |||
| if not client.schema.contains(schema): | |||
| client.schema.create_class(schema) | |||
| with client.batch as batch: | |||
| for i, text in enumerate(texts): | |||
| data_properties = { | |||
| text_key: text, | |||
| } | |||
| if metadatas is not None: | |||
| for key in metadatas[i].keys(): | |||
| data_properties[key] = metadatas[i][key] | |||
| # If the UUID of one of the objects already exists | |||
| # then the existing object will be replaced by the new object. | |||
| if "uuids" in kwargs: | |||
| _id = kwargs["uuids"][i] | |||
| else: | |||
| _id = get_valid_uuid(uuid4()) | |||
| # if an embedding strategy is not provided, we let | |||
| # weaviate create the embedding. Note that this will only | |||
| # work if weaviate has been installed with a vectorizer module | |||
| # like text2vec-contextionary for example | |||
| params = { | |||
| "uuid": _id, | |||
| "data_object": data_properties, | |||
| "class_name": index_name, | |||
| } | |||
| if embeddings is not None: | |||
| params["vector"] = embeddings[i] | |||
| batch.add_data_object(**params) | |||
| batch.flush() | |||
| relevance_score_fn = kwargs.get("relevance_score_fn") | |||
| by_text: bool = kwargs.get("by_text", False) | |||
| return cls( | |||
| client, | |||
| index_name, | |||
| text_key, | |||
| embedding=embedding, | |||
| attributes=attributes, | |||
| relevance_score_fn=relevance_score_fn, | |||
| by_text=by_text, | |||
| ) | |||
| def delete(self, ids: Optional[list[str]] = None, **kwargs: Any) -> None: | |||
| """Delete by vector IDs. | |||
| Args: | |||
| ids: List of ids to delete. | |||
| """ | |||
| if ids is None: | |||
| raise ValueError("No ids provided to delete.") | |||
| # TODO: Check if this can be done in bulk | |||
| for id in ids: | |||
| self._client.data_object.delete(uuid=id) | |||
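| A short usage sketch (not part of the diff) for the Weaviate wrapper above; the embedding object, URL and index name are illustrative assumptions: | |||
| from langchain.embeddings.base import Embeddings | |||
| def bm25_lookup(texts: list[str], embedding: Embeddings, query: str) -> list[Document]: | |||
|     # from_texts creates the class schema if it does not exist yet and batch-inserts the objects. | |||
|     store = Weaviate.from_texts( | |||
|         texts, | |||
|         embedding, | |||
|         weaviate_url="http://localhost:8080", | |||
|         index_name="ExampleIndex", | |||
|     ) | |||
|     # Keyword (BM25F) retrieval over the 'text' property, as implemented above. | |||
|     return store.similarity_search_by_bm25(query, k=4) | |||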
| @@ -1,38 +0,0 @@ | |||
| from core.vector_store.vector.weaviate import Weaviate | |||
| class WeaviateVectorStore(Weaviate): | |||
| def del_texts(self, where_filter: dict): | |||
| if not where_filter: | |||
| raise ValueError('where_filter must not be empty') | |||
| self._client.batch.delete_objects( | |||
| class_name=self._index_name, | |||
| where=where_filter, | |||
| output='minimal' | |||
| ) | |||
| def del_text(self, uuid: str) -> None: | |||
| self._client.data_object.delete( | |||
| uuid, | |||
| class_name=self._index_name | |||
| ) | |||
| def text_exists(self, uuid: str) -> bool: | |||
| result = self._client.query.get(self._index_name).with_additional(["id"]).with_where({ | |||
| "path": ["doc_id"], | |||
| "operator": "Equal", | |||
| "valueText": uuid, | |||
| }).with_limit(1).do() | |||
| if "errors" in result: | |||
| raise ValueError(f"Error during query: {result['errors']}") | |||
| entries = result["data"]["Get"][self._index_name] | |||
| if len(entries) == 0: | |||
| return False | |||
| return True | |||
| def delete(self): | |||
| self._client.schema.delete_class(self._index_name) | |||
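| A brief sketch (not part of the diff) of how the subclass above can be exercised; the where-filter keys and the pre-built store are illustrative assumptions: | |||
| def purge_document(store: WeaviateVectorStore, document_id: str, node_ids: list[str]) -> None: | |||
|     # Bulk-delete every object belonging to one document via a metadata where filter. | |||
|     store.del_texts({ | |||
|         "path": ["document_id"], | |||
|         "operator": "Equal", | |||
|         "valueText": document_id, | |||
|     }) | |||
|     # text_exists matches on the doc_id property, so leftovers can be checked per node. | |||
|     assert not any(store.text_exists(node_id) for node_id in node_ids) | |||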
| @@ -6,4 +6,4 @@ from tasks.clean_dataset_task import clean_dataset_task | |||
| def handle(sender, **kwargs): | |||
| dataset = sender | |||
| clean_dataset_task.delay(dataset.id, dataset.tenant_id, dataset.indexing_technique, | |||
| dataset.index_struct, dataset.collection_binding_id) | |||
| dataset.index_struct, dataset.collection_binding_id, dataset.doc_form) | |||
| @@ -6,4 +6,5 @@ from tasks.clean_document_task import clean_document_task | |||
| def handle(sender, **kwargs): | |||
| document_id = sender | |||
| dataset_id = kwargs.get('dataset_id') | |||
| clean_document_task.delay(document_id, dataset_id) | |||
| doc_form = kwargs.get('doc_form') | |||
| clean_document_task.delay(document_id, dataset_id, doc_form) | |||
| @@ -94,6 +94,14 @@ class Dataset(db.Model): | |||
| return Document.query.with_entities(func.coalesce(func.sum(Document.word_count))) \ | |||
| .filter(Document.dataset_id == self.id).scalar() | |||
| @property | |||
| def doc_form(self): | |||
| document = db.session.query(Document).filter( | |||
| Document.dataset_id == self.id).first() | |||
| if document: | |||
| return document.doc_form | |||
| return None | |||
| @property | |||
| def retrieval_model_dict(self): | |||
| default_retrieval_model = { | |||
| @@ -6,7 +6,7 @@ from flask import current_app | |||
| from werkzeug.exceptions import NotFound | |||
| import app | |||
| from core.index.index import IndexBuilder | |||
| from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | |||
| from extensions.ext_database import db | |||
| from models.dataset import Dataset, DatasetQuery, Document | |||
| @@ -41,18 +41,9 @@ def clean_unused_datasets_task(): | |||
| if not documents or len(documents) == 0: | |||
| try: | |||
| # remove index | |||
| vector_index = IndexBuilder.get_index(dataset, 'high_quality') | |||
| kw_index = IndexBuilder.get_index(dataset, 'economy') | |||
| # delete from vector index | |||
| if vector_index: | |||
| if dataset.collection_binding_id: | |||
| vector_index.delete_by_group_id(dataset.id) | |||
| else: | |||
| if dataset.collection_binding_id: | |||
| vector_index.delete_by_group_id(dataset.id) | |||
| else: | |||
| vector_index.delete() | |||
| kw_index.delete() | |||
| index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor() | |||
| index_processor.clean(dataset, None) | |||
| # update document | |||
| update_params = { | |||
| Document.enabled: False | |||
| @@ -11,10 +11,11 @@ from flask_login import current_user | |||
| from sqlalchemy import func | |||
| from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError | |||
| from core.index.index import IndexBuilder | |||
| from core.model_manager import ModelManager | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel | |||
| from core.rag.datasource.keyword.keyword_factory import Keyword | |||
| from core.rag.models.document import Document as RAGDocument | |||
| from events.dataset_event import dataset_was_deleted | |||
| from events.document_event import document_was_deleted | |||
| from extensions.ext_database import db | |||
| @@ -402,7 +403,7 @@ class DocumentService: | |||
| @staticmethod | |||
| def delete_document(document): | |||
| # trigger document_was_deleted signal | |||
| document_was_deleted.send(document.id, dataset_id=document.dataset_id) | |||
| document_was_deleted.send(document.id, dataset_id=document.dataset_id, doc_form=document.doc_form) | |||
| db.session.delete(document) | |||
| db.session.commit() | |||
| @@ -1060,7 +1061,7 @@ class SegmentService: | |||
| # save vector index | |||
| try: | |||
| VectorService.create_segment_vector(args['keywords'], segment_document, dataset) | |||
| VectorService.create_segments_vector([args['keywords']], [segment_document], dataset) | |||
| except Exception as e: | |||
| logging.exception("create segment index failed") | |||
| segment_document.enabled = False | |||
| @@ -1087,6 +1088,7 @@ class SegmentService: | |||
| ).scalar() | |||
| pre_segment_data_list = [] | |||
| segment_data_list = [] | |||
| keywords_list = [] | |||
| for segment_item in segments: | |||
| content = segment_item['content'] | |||
| doc_id = str(uuid.uuid4()) | |||
| @@ -1119,15 +1121,13 @@ class SegmentService: | |||
| segment_document.answer = segment_item['answer'] | |||
| db.session.add(segment_document) | |||
| segment_data_list.append(segment_document) | |||
| pre_segment_data = { | |||
| 'segment': segment_document, | |||
| 'keywords': segment_item['keywords'] | |||
| } | |||
| pre_segment_data_list.append(pre_segment_data) | |||
| pre_segment_data_list.append(segment_document) | |||
| keywords_list.append(segment_item['keywords']) | |||
| try: | |||
| # save vector index | |||
| VectorService.multi_create_segment_vector(pre_segment_data_list, dataset) | |||
| VectorService.create_segments_vector(keywords_list, pre_segment_data_list, dataset) | |||
| except Exception as e: | |||
| logging.exception("create segment index failed") | |||
| for segment_document in segment_data_list: | |||
| @@ -1157,11 +1157,18 @@ class SegmentService: | |||
| db.session.commit() | |||
| # update segment index task | |||
| if args['keywords']: | |||
| kw_index = IndexBuilder.get_index(dataset, 'economy') | |||
| # delete from keyword index | |||
| kw_index.delete_by_ids([segment.index_node_id]) | |||
| # save keyword index | |||
| kw_index.update_segment_keywords_index(segment.index_node_id, segment.keywords) | |||
| keyword = Keyword(dataset) | |||
| keyword.delete_by_ids([segment.index_node_id]) | |||
| document = RAGDocument( | |||
| page_content=segment.content, | |||
| metadata={ | |||
| "doc_id": segment.index_node_id, | |||
| "doc_hash": segment.index_node_hash, | |||
| "document_id": segment.document_id, | |||
| "dataset_id": segment.dataset_id, | |||
| } | |||
| ) | |||
| keyword.add_texts([document], keywords_list=[args['keywords']]) | |||
| else: | |||
| segment_hash = helper.generate_text_hash(content) | |||
| tokens = 0 | |||
| @@ -9,8 +9,8 @@ from flask_login import current_user | |||
| from werkzeug.datastructures import FileStorage | |||
| from werkzeug.exceptions import NotFound | |||
| from core.data_loader.file_extractor import FileExtractor | |||
| from core.file.upload_file_parser import UploadFileParser | |||
| from core.rag.extractor.extract_processor import ExtractProcessor | |||
| from extensions.ext_database import db | |||
| from extensions.ext_storage import storage | |||
| from models.account import Account | |||
| @@ -32,7 +32,8 @@ class FileService: | |||
| def upload_file(file: FileStorage, user: Union[Account, EndUser], only_image: bool = False) -> UploadFile: | |||
| extension = file.filename.split('.')[-1] | |||
| etl_type = current_app.config['ETL_TYPE'] | |||
| allowed_extensions = UNSTRUSTURED_ALLOWED_EXTENSIONS if etl_type == 'Unstructured' else ALLOWED_EXTENSIONS | |||
| allowed_extensions = UNSTRUSTURED_ALLOWED_EXTENSIONS + IMAGE_EXTENSIONS if etl_type == 'Unstructured' \ | |||
| else ALLOWED_EXTENSIONS + IMAGE_EXTENSIONS | |||
| if extension.lower() not in allowed_extensions: | |||
| raise UnsupportedFileTypeError() | |||
| elif only_image and extension.lower() not in IMAGE_EXTENSIONS: | |||
| @@ -136,7 +137,7 @@ class FileService: | |||
| if extension.lower() not in allowed_extensions: | |||
| raise UnsupportedFileTypeError() | |||
| text = FileExtractor.load(upload_file, return_text=True) | |||
| text = ExtractProcessor.load_from_upload_file(upload_file, return_text=True) | |||
| text = text[0:PREVIEW_WORDS_LIMIT] if text else '' | |||
| return text | |||
| @@ -164,7 +165,7 @@ class FileService: | |||
| return generator, upload_file.mime_type | |||
| @staticmethod | |||
| def get_public_image_preview(file_id: str) -> str: | |||
| def get_public_image_preview(file_id: str) -> tuple[Generator, str]: | |||
| upload_file = db.session.query(UploadFile) \ | |||
| .filter(UploadFile.id == file_id) \ | |||
| .first() | |||
| @@ -1,21 +1,18 @@ | |||
| import logging | |||
| import threading | |||
| import time | |||
| import numpy as np | |||
| from flask import current_app | |||
| from langchain.embeddings.base import Embeddings | |||
| from langchain.schema import Document | |||
| from sklearn.manifold import TSNE | |||
| from core.embedding.cached_embedding import CacheEmbedding | |||
| from core.model_manager import ModelManager | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from core.rerank.rerank import RerankRunner | |||
| from core.rag.datasource.entity.embedding import Embeddings | |||
| from core.rag.datasource.retrieval_service import RetrievalService | |||
| from core.rag.models.document import Document | |||
| from extensions.ext_database import db | |||
| from models.account import Account | |||
| from models.dataset import Dataset, DatasetQuery, DocumentSegment | |||
| from services.retrieval_service import RetrievalService | |||
| default_retrieval_model = { | |||
| 'search_method': 'semantic_search', | |||
| @@ -28,6 +25,7 @@ default_retrieval_model = { | |||
| 'score_threshold_enabled': False | |||
| } | |||
| class HitTestingService: | |||
| @classmethod | |||
| def retrieve(cls, dataset: Dataset, query: str, account: Account, retrieval_model: dict, limit: int = 10) -> dict: | |||
| @@ -57,61 +55,15 @@ class HitTestingService: | |||
| embeddings = CacheEmbedding(embedding_model) | |||
| all_documents = [] | |||
| threads = [] | |||
| # retrieval_model source with semantic | |||
| if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search': | |||
| embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={ | |||
| 'flask_app': current_app._get_current_object(), | |||
| 'dataset_id': str(dataset.id), | |||
| 'query': query, | |||
| 'top_k': retrieval_model['top_k'], | |||
| 'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enabled'] else None, | |||
| 'reranking_model': retrieval_model['reranking_model'] if retrieval_model['reranking_enable'] else None, | |||
| 'all_documents': all_documents, | |||
| 'search_method': retrieval_model['search_method'], | |||
| 'embeddings': embeddings | |||
| }) | |||
| threads.append(embedding_thread) | |||
| embedding_thread.start() | |||
| # retrieval source with full text | |||
| if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search': | |||
| full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={ | |||
| 'flask_app': current_app._get_current_object(), | |||
| 'dataset_id': str(dataset.id), | |||
| 'query': query, | |||
| 'search_method': retrieval_model['search_method'], | |||
| 'embeddings': embeddings, | |||
| 'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enabled'] else None, | |||
| 'top_k': retrieval_model['top_k'], | |||
| 'reranking_model': retrieval_model['reranking_model'] if retrieval_model['reranking_enable'] else None, | |||
| 'all_documents': all_documents | |||
| }) | |||
| threads.append(full_text_index_thread) | |||
| full_text_index_thread.start() | |||
| for thread in threads: | |||
| thread.join() | |||
| if retrieval_model['search_method'] == 'hybrid_search': | |||
| model_manager = ModelManager() | |||
| rerank_model_instance = model_manager.get_model_instance( | |||
| tenant_id=dataset.tenant_id, | |||
| provider=retrieval_model['reranking_model']['reranking_provider_name'], | |||
| model_type=ModelType.RERANK, | |||
| model=retrieval_model['reranking_model']['reranking_model_name'] | |||
| ) | |||
| rerank_runner = RerankRunner(rerank_model_instance) | |||
| all_documents = rerank_runner.run( | |||
| query=query, | |||
| documents=all_documents, | |||
| score_threshold=retrieval_model['score_threshold'] if retrieval_model['score_threshold_enabled'] else None, | |||
| top_n=retrieval_model['top_k'], | |||
| user=f"account-{account.id}" | |||
| ) | |||
| all_documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'], | |||
| dataset_id=dataset.id, | |||
| query=query, | |||
| top_k=retrieval_model['top_k'], | |||
| score_threshold=retrieval_model['score_threshold'] | |||
| if retrieval_model['score_threshold_enabled'] else None, | |||
| reranking_model=retrieval_model['reranking_model'] | |||
| if retrieval_model['reranking_enable'] else None | |||
| ) | |||
| end = time.perf_counter() | |||
| logging.debug(f"Hit testing retrieve in {end - start:0.4f} seconds") | |||
| @@ -203,4 +155,3 @@ class HitTestingService: | |||
| if not query or len(query) > 250: | |||
| raise ValueError('Query is required and cannot exceed 250 characters') | |||
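| For reference (not part of the diff), the consolidated retrieval call above condenses to a small helper; the argument names, including the retrival_method keyword spelling, follow the hunk above: | |||
| def hit_test_retrieve(dataset: Dataset, query: str, retrieval_model: dict) -> list: | |||
|     # One call now covers semantic, full-text and hybrid search plus optional reranking. | |||
|     return RetrievalService.retrieve( | |||
|         retrival_method=retrieval_model['search_method'], | |||
|         dataset_id=dataset.id, | |||
|         query=query, | |||
|         top_k=retrieval_model['top_k'], | |||
|         score_threshold=retrieval_model['score_threshold'] | |||
|         if retrieval_model['score_threshold_enabled'] else None, | |||
|         reranking_model=retrieval_model['reranking_model'] | |||
|         if retrieval_model['reranking_enable'] else None, | |||
|     ) | |||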
| @@ -1,119 +0,0 @@ | |||
| from typing import Optional | |||
| from flask import Flask, current_app | |||
| from langchain.embeddings.base import Embeddings | |||
| from core.index.vector_index.vector_index import VectorIndex | |||
| from core.model_manager import ModelManager | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from core.model_runtime.errors.invoke import InvokeAuthorizationError | |||
| from core.rerank.rerank import RerankRunner | |||
| from extensions.ext_database import db | |||
| from models.dataset import Dataset | |||
| default_retrieval_model = { | |||
| 'search_method': 'semantic_search', | |||
| 'reranking_enable': False, | |||
| 'reranking_model': { | |||
| 'reranking_provider_name': '', | |||
| 'reranking_model_name': '' | |||
| }, | |||
| 'top_k': 2, | |||
| 'score_threshold_enabled': False | |||
| } | |||
| class RetrievalService: | |||
| @classmethod | |||
| def embedding_search(cls, flask_app: Flask, dataset_id: str, query: str, | |||
| top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict], | |||
| all_documents: list, search_method: str, embeddings: Embeddings): | |||
| with flask_app.app_context(): | |||
| dataset = db.session.query(Dataset).filter( | |||
| Dataset.id == dataset_id | |||
| ).first() | |||
| vector_index = VectorIndex( | |||
| dataset=dataset, | |||
| config=current_app.config, | |||
| embeddings=embeddings | |||
| ) | |||
| documents = vector_index.search( | |||
| query, | |||
| search_type='similarity_score_threshold', | |||
| search_kwargs={ | |||
| 'k': top_k, | |||
| 'score_threshold': score_threshold, | |||
| 'filter': { | |||
| 'group_id': [dataset.id] | |||
| } | |||
| } | |||
| ) | |||
| if documents: | |||
| if reranking_model and search_method == 'semantic_search': | |||
| try: | |||
| model_manager = ModelManager() | |||
| rerank_model_instance = model_manager.get_model_instance( | |||
| tenant_id=dataset.tenant_id, | |||
| provider=reranking_model['reranking_provider_name'], | |||
| model_type=ModelType.RERANK, | |||
| model=reranking_model['reranking_model_name'] | |||
| ) | |||
| except InvokeAuthorizationError: | |||
| return | |||
| rerank_runner = RerankRunner(rerank_model_instance) | |||
| all_documents.extend(rerank_runner.run( | |||
| query=query, | |||
| documents=documents, | |||
| score_threshold=score_threshold, | |||
| top_n=len(documents) | |||
| )) | |||
| else: | |||
| all_documents.extend(documents) | |||
| @classmethod | |||
| def full_text_index_search(cls, flask_app: Flask, dataset_id: str, query: str, | |||
| top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict], | |||
| all_documents: list, search_method: str, embeddings: Embeddings): | |||
| with flask_app.app_context(): | |||
| dataset = db.session.query(Dataset).filter( | |||
| Dataset.id == dataset_id | |||
| ).first() | |||
| vector_index = VectorIndex( | |||
| dataset=dataset, | |||
| config=current_app.config, | |||
| embeddings=embeddings | |||
| ) | |||
| documents = vector_index.search_by_full_text_index( | |||
| query, | |||
| search_type='similarity_score_threshold', | |||
| top_k=top_k | |||
| ) | |||
| if documents: | |||
| if reranking_model and search_method == 'full_text_search': | |||
| try: | |||
| model_manager = ModelManager() | |||
| rerank_model_instance = model_manager.get_model_instance( | |||
| tenant_id=dataset.tenant_id, | |||
| provider=reranking_model['reranking_provider_name'], | |||
| model_type=ModelType.RERANK, | |||
| model=reranking_model['reranking_model_name'] | |||
| ) | |||
| except InvokeAuthorizationError: | |||
| return | |||
| rerank_runner = RerankRunner(rerank_model_instance) | |||
| all_documents.extend(rerank_runner.run( | |||
| query=query, | |||
| documents=documents, | |||
| score_threshold=score_threshold, | |||
| top_n=len(documents) | |||
| )) | |||
| else: | |||
| all_documents.extend(documents) | |||
| @@ -1,44 +1,18 @@ | |||
| from typing import Optional | |||
| from langchain.schema import Document | |||
| from core.index.index import IndexBuilder | |||
| from core.rag.datasource.keyword.keyword_factory import Keyword | |||
| from core.rag.datasource.vdb.vector_factory import Vector | |||
| from core.rag.models.document import Document | |||
| from models.dataset import Dataset, DocumentSegment | |||
| class VectorService: | |||
| @classmethod | |||
| def create_segment_vector(cls, keywords: Optional[list[str]], segment: DocumentSegment, dataset: Dataset): | |||
| document = Document( | |||
| page_content=segment.content, | |||
| metadata={ | |||
| "doc_id": segment.index_node_id, | |||
| "doc_hash": segment.index_node_hash, | |||
| "document_id": segment.document_id, | |||
| "dataset_id": segment.dataset_id, | |||
| } | |||
| ) | |||
| # save vector index | |||
| index = IndexBuilder.get_index(dataset, 'high_quality') | |||
| if index: | |||
| index.add_texts([document], duplicate_check=True) | |||
| # save keyword index | |||
| index = IndexBuilder.get_index(dataset, 'economy') | |||
| if index: | |||
| if keywords and len(keywords) > 0: | |||
| index.create_segment_keywords(segment.index_node_id, keywords) | |||
| else: | |||
| index.add_texts([document]) | |||
| @classmethod | |||
| def multi_create_segment_vector(cls, pre_segment_data_list: list, dataset: Dataset): | |||
| def create_segments_vector(cls, keywords_list: Optional[list[list[str]]], | |||
| segments: list[DocumentSegment], dataset: Dataset): | |||
| documents = [] | |||
| for pre_segment_data in pre_segment_data_list: | |||
| segment = pre_segment_data['segment'] | |||
| for segment in segments: | |||
| document = Document( | |||
| page_content=segment.content, | |||
| metadata={ | |||
| @@ -49,30 +23,26 @@ class VectorService: | |||
| } | |||
| ) | |||
| documents.append(document) | |||
| # save vector index | |||
| index = IndexBuilder.get_index(dataset, 'high_quality') | |||
| if index: | |||
| index.add_texts(documents, duplicate_check=True) | |||
| if dataset.indexing_technique == 'high_quality': | |||
| # save vector index | |||
| vector = Vector( | |||
| dataset=dataset | |||
| ) | |||
| vector.add_texts(documents, duplicate_check=True) | |||
| # save keyword index | |||
| keyword_index = IndexBuilder.get_index(dataset, 'economy') | |||
| if keyword_index: | |||
| keyword_index.multi_create_segment_keywords(pre_segment_data_list) | |||
| keyword = Keyword(dataset) | |||
| if keywords_list and len(keywords_list) > 0: | |||
| keyword.add_texts(documents, keywords_list=keywords_list) | |||
| else: | |||
| keyword.add_texts(documents) | |||
| @classmethod | |||
| def update_segment_vector(cls, keywords: Optional[list[str]], segment: DocumentSegment, dataset: Dataset): | |||
| # update segment index task | |||
| vector_index = IndexBuilder.get_index(dataset, 'high_quality') | |||
| kw_index = IndexBuilder.get_index(dataset, 'economy') | |||
| # delete from vector index | |||
| if vector_index: | |||
| vector_index.delete_by_ids([segment.index_node_id]) | |||
| # delete from keyword index | |||
| kw_index.delete_by_ids([segment.index_node_id]) | |||
| # add new index | |||
| # format new index | |||
| document = Document( | |||
| page_content=segment.content, | |||
| metadata={ | |||
| @@ -82,13 +52,20 @@ class VectorService: | |||
| "dataset_id": segment.dataset_id, | |||
| } | |||
| ) | |||
| if dataset.indexing_technique == 'high_quality': | |||
| # update vector index | |||
| vector = Vector( | |||
| dataset=dataset | |||
| ) | |||
| vector.delete_by_ids([segment.index_node_id]) | |||
| vector.add_texts([document], duplicate_check=True) | |||
| # save vector index | |||
| if vector_index: | |||
| vector_index.add_texts([document], duplicate_check=True) | |||
| # update keyword index | |||
| keyword = Keyword(dataset) | |||
| keyword.delete_by_ids([segment.index_node_id]) | |||
| # save keyword index | |||
| if keywords and len(keywords) > 0: | |||
| kw_index.create_segment_keywords(segment.index_node_id, keywords) | |||
| keyword.add_texts([document], keywords_list=[keywords]) | |||
| else: | |||
| kw_index.add_texts([document]) | |||
| keyword.add_texts([document]) | |||
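| A closing sketch (not part of the diff) of how the refactored VectorService is driven from SegmentService, per the hunks above; the already-loaded segment and dataset rows are assumed: | |||
| def index_new_segments(keywords_list: list[list[str]], segments: list[DocumentSegment], dataset: Dataset) -> None: | |||
|     # High-quality datasets additionally get a vector index; every dataset gets a keyword index. | |||
|     VectorService.create_segments_vector(keywords_list, segments, dataset) | |||
| def reindex_segment(keywords: list[str], segment: DocumentSegment, dataset: Dataset) -> None: | |||
|     # Removes the old index node from both stores, then re-adds the rebuilt document. | |||
|     VectorService.update_segment_vector(keywords, segment, dataset) | |||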