    'BILLING_ENABLED': 'False',
    'CAN_REPLACE_LOGO': 'False',
    'ETL_TYPE': 'dify',
    'KEYWORD_STORE': 'jieba',
    'BATCH_UPLOAD_LIMIT': 20
}

# Currently, only support: qdrant, milvus, zilliz, weaviate
# ------------------------
self.VECTOR_STORE = get_env('VECTOR_STORE')
self.KEYWORD_STORE = get_env('KEYWORD_STORE')

# qdrant settings
self.QDRANT_URL = get_env('QDRANT_URL')
self.QDRANT_API_KEY = get_env('QDRANT_API_KEY')
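# Illustrative only: a minimal .env sketch for the settings read above. The variable
# names come from this config; the values are placeholders, not project defaults
# (except KEYWORD_STORE=jieba, which matches the default dict above).
#
#   VECTOR_STORE=qdrant
#   KEYWORD_STORE=jieba
#   QDRANT_URL=https://your-qdrant-host:6333
#   QDRANT_API_KEY=your-api-key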
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.indexing_runner import IndexingRunner
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.notion_extractor import NotionExtractor
from extensions.ext_database import db
from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
from libs.login import login_required

if not data_source_binding:
    raise NotFound('Data source binding not found.')

extractor = NotionExtractor(
    notion_workspace_id=workspace_id,
    notion_obj_id=page_id,
    notion_page_type=page_type,
    notion_access_token=data_source_binding.access_token
)

text_docs = extractor.extract()
return {
    'content': "\n".join([doc.page_content for doc in text_docs])
}, 200
parser = reqparse.RequestParser()
parser.add_argument('notion_info_list', type=list, required=True, nullable=True, location='json')
parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json')
args = parser.parse_args()

# validate args
DocumentService.estimate_args_validate(args)

notion_info_list = args['notion_info_list']
extract_settings = []
for notion_info in notion_info_list:
    workspace_id = notion_info['workspace_id']
    for page in notion_info['pages']:
        extract_setting = ExtractSetting(
            datasource_type="notion_import",
            notion_info={
                "notion_workspace_id": workspace_id,
                "notion_obj_id": page['page_id'],
                "notion_page_type": page['type']
            },
            document_model=args['doc_form']
        )
        extract_settings.append(extract_setting)

indexing_runner = IndexingRunner()
response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings,
                                             args['process_rule'], args['doc_form'],
                                             args['doc_language'])
return response, 200
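# Illustrative request body for this Notion estimate endpoint, inferred from the parser
# arguments and the extract-setting loop above; all values are placeholders.
#
#   {
#       "notion_info_list": [
#           {"workspace_id": "<workspace-id>",
#            "pages": [{"page_id": "<page-id>", "type": "page"}]}
#       ],
#       "process_rule": {"mode": "automatic", "rules": {}},
#       "doc_form": "text_model",
#       "doc_language": "English"
#   }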
from core.indexing_runner import IndexingRunner
from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager
from core.rag.extractor.entity.extract_setting import ExtractSetting
from extensions.ext_database import db
from fields.app_fields import related_app_list
from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields

                    location='json', store_missing=False,
                    type=_validate_description_length)
parser.add_argument('indexing_technique', type=str, location='json',
                    choices=Dataset.INDEXING_TECHNIQUE_LIST,
                    nullable=True,
                    help='Invalid indexing technique.')
parser.add_argument('permission', type=str, location='json', choices=(
    'only_me', 'all_team_members'), help='Invalid permission.')
parser.add_argument('retrieval_model', type=dict, location='json', help='Invalid retrieval model.')
parser = reqparse.RequestParser()
parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json')
parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
parser.add_argument('indexing_technique', type=str, required=True,
                    choices=Dataset.INDEXING_TECHNIQUE_LIST,
                    nullable=True, location='json')
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
args = parser.parse_args()

# validate args
DocumentService.estimate_args_validate(args)

extract_settings = []
if args['info_list']['data_source_type'] == 'upload_file':
    file_ids = args['info_list']['file_info_list']['file_ids']
    file_details = db.session.query(UploadFile).filter(

    if file_details is None:
        raise NotFound("File not found.")

    if file_details:
        for file_detail in file_details:
            extract_setting = ExtractSetting(
                datasource_type="upload_file",
                upload_file=file_detail,
                document_model=args['doc_form']
            )
            extract_settings.append(extract_setting)
elif args['info_list']['data_source_type'] == 'notion_import':
    notion_info_list = args['info_list']['notion_info_list']
    for notion_info in notion_info_list:
        workspace_id = notion_info['workspace_id']
        for page in notion_info['pages']:
            extract_setting = ExtractSetting(
                datasource_type="notion_import",
                notion_info={
                    "notion_workspace_id": workspace_id,
                    "notion_obj_id": page['page_id'],
                    "notion_page_type": page['type']
                },
                document_model=args['doc_form']
            )
            extract_settings.append(extract_setting)
else:
    raise ValueError('Data source type not support')

indexing_runner = IndexingRunner()
try:
    response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings,
                                                 args['process_rule'], args['doc_form'],
                                                 args['doc_language'], args['dataset_id'],
                                                 args['indexing_technique'])
except LLMBadRequestError:
    raise ProviderNotInitializeError(
        "No Embedding Model available. Please configure a valid provider "
        "in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
    raise ProviderNotInitializeError(ex.description)
return response, 200
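# Illustrative request body for the upload_file branch of this estimate endpoint,
# inferred from the parser arguments above; ids and rule values are placeholders.
#
#   {
#       "info_list": {
#           "data_source_type": "upload_file",
#           "file_info_list": {"file_ids": ["<upload-file-id>"]}
#       },
#       "process_rule": {"mode": "automatic", "rules": {}},
#       "indexing_technique": "high_quality",
#       "doc_form": "text_model",
#       "doc_language": "English"
#   }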
api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info')
api.add_resource(DatasetRetrievalSettingApi, '/datasets/retrieval-setting')
api.add_resource(DatasetRetrievalSettingMockApi, '/datasets/retrieval-setting/<string:vector_type>')
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.rag.extractor.entity.extract_setting import ExtractSetting
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from fields.document_fields import (

req_data = request.args
document_id = req_data.get('document_id')

# get default rules
mode = DocumentService.DEFAULT_RULES['mode']
rules = DocumentService.DEFAULT_RULES['rules']

if not file:
    raise NotFound('File not found.')

extract_setting = ExtractSetting(
    datasource_type="upload_file",
    upload_file=file,
    document_model=document.doc_form
)

indexing_runner = IndexingRunner()

try:
    response = indexing_runner.indexing_estimate(current_user.current_tenant_id, [extract_setting],
                                                 data_process_rule_dict, document.doc_form,
                                                 'English', dataset_id)
except LLMBadRequestError:
    raise ProviderNotInitializeError(
        "No Embedding Model available. Please configure a valid provider "
data_process_rule = documents[0].dataset_process_rule
data_process_rule_dict = data_process_rule.to_dict()

info_list = []
extract_settings = []
for document in documents:
    if document.indexing_status in ['completed', 'error']:
        raise DocumentAlreadyFinishedError()

        }
        info_list.append(notion_info)

    if document.data_source_type == 'upload_file':
        file_id = data_source_info['upload_file_id']
        file_detail = db.session.query(UploadFile).filter(
            UploadFile.tenant_id == current_user.current_tenant_id,
            UploadFile.id == file_id
        ).first()

        if file_detail is None:
            raise NotFound("File not found.")

        extract_setting = ExtractSetting(
            datasource_type="upload_file",
            upload_file=file_detail,
            document_model=document.doc_form
        )
        extract_settings.append(extract_setting)
    elif document.data_source_type == 'notion_import':
        extract_setting = ExtractSetting(
            datasource_type="notion_import",
            notion_info={
                "notion_workspace_id": data_source_info['notion_workspace_id'],
                "notion_obj_id": data_source_info['notion_page_id'],
                "notion_page_type": data_source_info['type']
            },
            document_model=document.doc_form
        )
        extract_settings.append(extract_setting)
    else:
        raise ValueError('Data source type not support')

indexing_runner = IndexingRunner()
try:
    response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings,
                                                 data_process_rule_dict, document.doc_form,
                                                 'English', dataset_id)
except LLMBadRequestError:
    raise ProviderNotInitializeError(
        "No Embedding Model available. Please configure a valid provider "
        "in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
    raise ProviderNotInitializeError(ex.description)

return response
import tempfile
from pathlib import Path
from typing import Optional, Union

import requests
from flask import current_app
from langchain.document_loaders import Docx2txtLoader, TextLoader
from langchain.schema import Document

from core.data_loader.loader.csv_loader import CSVLoader
from core.data_loader.loader.excel import ExcelLoader
from core.data_loader.loader.html import HTMLLoader
from core.data_loader.loader.markdown import MarkdownLoader
from core.data_loader.loader.pdf import PdfLoader
from core.data_loader.loader.unstructured.unstructured_eml import UnstructuredEmailLoader
from core.data_loader.loader.unstructured.unstructured_markdown import UnstructuredMarkdownLoader
from core.data_loader.loader.unstructured.unstructured_msg import UnstructuredMsgLoader
from core.data_loader.loader.unstructured.unstructured_ppt import UnstructuredPPTLoader
from core.data_loader.loader.unstructured.unstructured_pptx import UnstructuredPPTXLoader
from core.data_loader.loader.unstructured.unstructured_text import UnstructuredTextLoader
from core.data_loader.loader.unstructured.unstructured_xml import UnstructuredXmlLoader
from extensions.ext_storage import storage
from models.model import UploadFile

SUPPORT_URL_CONTENT_TYPES = ['application/pdf', 'text/plain']
USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"


class FileExtractor:
    @classmethod
    def load(cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False) -> Union[list[Document], str]:
        with tempfile.TemporaryDirectory() as temp_dir:
            suffix = Path(upload_file.key).suffix
            file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
            storage.download(upload_file.key, file_path)

            return cls.load_from_file(file_path, return_text, upload_file, is_automatic)

    @classmethod
    def load_from_url(cls, url: str, return_text: bool = False) -> Union[list[Document], str]:
        response = requests.get(url, headers={
            "User-Agent": USER_AGENT
        })

        with tempfile.TemporaryDirectory() as temp_dir:
            suffix = Path(url).suffix
            file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
            with open(file_path, 'wb') as file:
                file.write(response.content)

            return cls.load_from_file(file_path, return_text)

    @classmethod
    def load_from_file(cls, file_path: str, return_text: bool = False,
                       upload_file: Optional[UploadFile] = None,
                       is_automatic: bool = False) -> Union[list[Document], str]:
        input_file = Path(file_path)
        delimiter = '\n'
        file_extension = input_file.suffix.lower()
        etl_type = current_app.config['ETL_TYPE']
        unstructured_api_url = current_app.config['UNSTRUCTURED_API_URL']
        if etl_type == 'Unstructured':
            if file_extension == '.xlsx':
                loader = ExcelLoader(file_path)
            elif file_extension == '.pdf':
                loader = PdfLoader(file_path, upload_file=upload_file)
            elif file_extension in ['.md', '.markdown']:
                loader = UnstructuredMarkdownLoader(file_path, unstructured_api_url) if is_automatic \
                    else MarkdownLoader(file_path, autodetect_encoding=True)
            elif file_extension in ['.htm', '.html']:
                loader = HTMLLoader(file_path)
            elif file_extension in ['.docx']:
                loader = Docx2txtLoader(file_path)
            elif file_extension == '.csv':
                loader = CSVLoader(file_path, autodetect_encoding=True)
            elif file_extension == '.msg':
                loader = UnstructuredMsgLoader(file_path, unstructured_api_url)
            elif file_extension == '.eml':
                loader = UnstructuredEmailLoader(file_path, unstructured_api_url)
            elif file_extension == '.ppt':
                loader = UnstructuredPPTLoader(file_path, unstructured_api_url)
            elif file_extension == '.pptx':
                loader = UnstructuredPPTXLoader(file_path, unstructured_api_url)
            elif file_extension == '.xml':
                loader = UnstructuredXmlLoader(file_path, unstructured_api_url)
            else:
                # txt
                loader = UnstructuredTextLoader(file_path, unstructured_api_url) if is_automatic \
                    else TextLoader(file_path, autodetect_encoding=True)
        else:
            if file_extension == '.xlsx':
                loader = ExcelLoader(file_path)
            elif file_extension == '.pdf':
                loader = PdfLoader(file_path, upload_file=upload_file)
            elif file_extension in ['.md', '.markdown']:
                loader = MarkdownLoader(file_path, autodetect_encoding=True)
            elif file_extension in ['.htm', '.html']:
                loader = HTMLLoader(file_path)
            elif file_extension in ['.docx']:
                loader = Docx2txtLoader(file_path)
            elif file_extension == '.csv':
                loader = CSVLoader(file_path, autodetect_encoding=True)
            else:
                # txt
                loader = TextLoader(file_path, autodetect_encoding=True)

        return delimiter.join([document.page_content for document in loader.load()]) if return_text else loader.load()
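# Minimal usage sketch for FileExtractor, assuming an UploadFile row already exists and
# its key is reachable through `storage` (db and file_id are placeholders, not shown above):
#
#   upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first()
#   text = FileExtractor.load(upload_file, return_text=True)        # joined plain text
#   documents = FileExtractor.load(upload_file)                     # list[Document]
#   docs_from_url = FileExtractor.load_from_url("https://example.com/a.pdf")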
import logging

from bs4 import BeautifulSoup
from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document

logger = logging.getLogger(__name__)


class HTMLLoader(BaseLoader):
    """Load html files.

    Args:
        file_path: Path to the file to load.
    """

    def __init__(
        self,
        file_path: str
    ):
        """Initialize with file path."""
        self._file_path = file_path

    def load(self) -> list[Document]:
        return [Document(page_content=self._load_as_text())]

    def _load_as_text(self) -> str:
        with open(self._file_path, "rb") as fp:
            soup = BeautifulSoup(fp, 'html.parser')
            text = soup.get_text()
            text = text.strip() if text else ''

        return text
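# Usage sketch: HTMLLoader strips tags with BeautifulSoup and returns a single Document.
#
#   docs = HTMLLoader('/path/to/page.html').load()
#   print(docs[0].page_content)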
import logging
from typing import Optional

from langchain.document_loaders import PyPDFium2Loader
from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document

from extensions.ext_storage import storage
from models.model import UploadFile

logger = logging.getLogger(__name__)


class PdfLoader(BaseLoader):
    """Load pdf files.

    Args:
        file_path: Path to the file to load.
    """

    def __init__(
        self,
        file_path: str,
        upload_file: Optional[UploadFile] = None
    ):
        """Initialize with file path."""
        self._file_path = file_path
        self._upload_file = upload_file

    def load(self) -> list[Document]:
        plaintext_file_key = ''
        plaintext_file_exists = False
        if self._upload_file:
            if self._upload_file.hash:
                plaintext_file_key = 'upload_files/' + self._upload_file.tenant_id + '/' \
                                     + self._upload_file.hash + '.0625.plaintext'
                try:
                    text = storage.load(plaintext_file_key).decode('utf-8')
                    plaintext_file_exists = True
                    return [Document(page_content=text)]
                except FileNotFoundError:
                    pass

        documents = PyPDFium2Loader(file_path=self._file_path).load()
        text_list = []
        for document in documents:
            text_list.append(document.page_content)
        text = "\n\n".join(text_list)

        # save plaintext file for caching
        if not plaintext_file_exists and plaintext_file_key:
            storage.save(plaintext_file_key, text.encode('utf-8'))

        return documents
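# Usage sketch: when an UploadFile with a hash is passed, the extracted text is cached in
# storage under 'upload_files/<tenant_id>/<hash>.0625.plaintext' and reused on later loads.
#
#   docs = PdfLoader('/tmp/file.pdf', upload_file=upload_file).load()   # upload_file may be None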
from collections.abc import Sequence
from typing import Any, Optional, cast

from sqlalchemy import func

from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
import logging
from typing import Optional

from core.entities.application_entities import InvokeFrom
from core.rag.datasource.vdb.vector_factory import Vector
from extensions.ext_database import db
from models.dataset import Dataset
from models.model import App, AppAnnotationSetting, Message, MessageAnnotation

embedding_provider_name = collection_binding_detail.provider_name
embedding_model_name = collection_binding_detail.model_name

dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
    embedding_provider_name,
    embedding_model_name,
    collection_binding_id=dataset_collection_binding.id
)

vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id'])
documents = vector.search_by_vector(
    query=query,
    k=1,
    score_threshold=score_threshold,
    filter={
        'group_id': [dataset.id]
    }
)
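# Note: Vector (core.rag.datasource.vdb.vector_factory) is the new store-agnostic facade;
# search_by_vector() takes k / score_threshold / filter directly, replacing the old
# VectorIndex.search(search_type=..., search_kwargs=...) call shape used here before.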
from flask import current_app
from langchain.embeddings import OpenAIEmbeddings

from core.embedding.cached_embedding import CacheEmbedding
from core.index.keyword_table_index.keyword_table_index import KeywordTableConfig, KeywordTableIndex
from core.index.vector_index.vector_index import VectorIndex
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from models.dataset import Dataset


class IndexBuilder:
    @classmethod
    def get_index(cls, dataset: Dataset, indexing_technique: str, ignore_high_quality_check: bool = False):
        if indexing_technique == "high_quality":
            if not ignore_high_quality_check and dataset.indexing_technique != 'high_quality':
                return None

            model_manager = ModelManager()
            embedding_model = model_manager.get_model_instance(
                tenant_id=dataset.tenant_id,
                model_type=ModelType.TEXT_EMBEDDING,
                provider=dataset.embedding_model_provider,
                model=dataset.embedding_model
            )

            embeddings = CacheEmbedding(embedding_model)

            return VectorIndex(
                dataset=dataset,
                config=current_app.config,
                embeddings=embeddings
            )
        elif indexing_technique == "economy":
            return KeywordTableIndex(
                dataset=dataset,
                config=KeywordTableConfig(
                    max_keywords_per_chunk=10
                )
            )
        else:
            raise ValueError('Unknown indexing technique')

    @classmethod
    def get_default_high_quality_index(cls, dataset: Dataset):
        embeddings = OpenAIEmbeddings(openai_api_key=' ')
        return VectorIndex(
            dataset=dataset,
            config=current_app.config,
            embeddings=embeddings
        )
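# Usage sketch for IndexBuilder; get_index() returns None when a dataset that is not
# configured as high_quality asks for a high_quality index and the check is not ignored.
#
#   index = IndexBuilder.get_index(dataset, 'high_quality')
#   if index:
#       index.add_texts(documents)
#   keyword_index = IndexBuilder.get_index(dataset, 'economy')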
import json
import logging
from abc import abstractmethod
from typing import Any, cast

from langchain.embeddings.base import Embeddings
from langchain.schema import BaseRetriever, Document
from langchain.vectorstores import VectorStore

from core.index.base import BaseIndex
from extensions.ext_database import db
from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment
from models.dataset import Document as DatasetDocument


class BaseVectorIndex(BaseIndex):

    def __init__(self, dataset: Dataset, embeddings: Embeddings):
        super().__init__(dataset)
        self._embeddings = embeddings
        self._vector_store = None

    def get_type(self) -> str:
        raise NotImplementedError

    @abstractmethod
    def get_index_name(self, dataset: Dataset) -> str:
        raise NotImplementedError

    @abstractmethod
    def to_index_struct(self) -> dict:
        raise NotImplementedError

    @abstractmethod
    def _get_vector_store(self) -> VectorStore:
        raise NotImplementedError

    @abstractmethod
    def _get_vector_store_class(self) -> type:
        raise NotImplementedError

    @abstractmethod
    def search_by_full_text_index(
            self, query: str,
            **kwargs: Any
    ) -> list[Document]:
        raise NotImplementedError

    def search(
            self, query: str,
            **kwargs: Any
    ) -> list[Document]:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        search_type = kwargs.get('search_type') if kwargs.get('search_type') else 'similarity'
        search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {}

        if search_type == 'similarity_score_threshold':
            score_threshold = search_kwargs.get("score_threshold")
            if (score_threshold is None) or (not isinstance(score_threshold, float)):
                search_kwargs['score_threshold'] = .0

            docs_with_similarity = vector_store.similarity_search_with_relevance_scores(
                query, **search_kwargs
            )

            docs = []
            for doc, similarity in docs_with_similarity:
                doc.metadata['score'] = similarity
                docs.append(doc)

            return docs

        # similarity k
        # mmr k, fetch_k, lambda_mult
        # similarity_score_threshold k
        return vector_store.as_retriever(
            search_type=search_type,
            search_kwargs=search_kwargs
        ).get_relevant_documents(query)
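    # Example call shapes for search(), matching the kwargs noted above (a sketch; `index`
    # stands for any concrete BaseVectorIndex subclass instance):
    #   index.search(query, search_type='similarity', search_kwargs={'k': 4})
    #   index.search(query, search_type='similarity_score_threshold',
    #                search_kwargs={'k': 4, 'score_threshold': 0.5,
    #                               'filter': {'group_id': [dataset.id]}})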
    def get_retriever(self, **kwargs: Any) -> BaseRetriever:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        return vector_store.as_retriever(**kwargs)

    def add_texts(self, texts: list[Document], **kwargs):
        if self._is_origin():
            self.recreate_dataset(self.dataset)

        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        if kwargs.get('duplicate_check', False):
            texts = self._filter_duplicate_texts(texts)

        uuids = self._get_uuids(texts)
        vector_store.add_documents(texts, uuids=uuids)

    def text_exists(self, id: str) -> bool:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        return vector_store.text_exists(id)

    def delete_by_ids(self, ids: list[str]) -> None:
        if self._is_origin():
            self.recreate_dataset(self.dataset)
            return

        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        for node_id in ids:
            vector_store.del_text(node_id)

    def delete_by_group_id(self, group_id: str) -> None:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)
        if self.dataset.collection_binding_id:
            vector_store.delete_by_group_id(group_id)
        else:
            vector_store.delete()

    def delete(self) -> None:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        vector_store.delete()

    def _is_origin(self):
        return False

    def recreate_dataset(self, dataset: Dataset):
        logging.info(f"Recreating dataset {dataset.id}")

        try:
            self.delete()
        except Exception as e:
            raise e

        dataset_documents = db.session.query(DatasetDocument).filter(
            DatasetDocument.dataset_id == dataset.id,
            DatasetDocument.indexing_status == 'completed',
            DatasetDocument.enabled == True,
            DatasetDocument.archived == False,
        ).all()

        documents = []
        for dataset_document in dataset_documents:
            segments = db.session.query(DocumentSegment).filter(
                DocumentSegment.document_id == dataset_document.id,
                DocumentSegment.status == 'completed',
                DocumentSegment.enabled == True
            ).all()

            for segment in segments:
                document = Document(
                    page_content=segment.content,
                    metadata={
                        "doc_id": segment.index_node_id,
                        "doc_hash": segment.index_node_hash,
                        "document_id": segment.document_id,
                        "dataset_id": segment.dataset_id,
                    }
                )

                documents.append(document)

        origin_index_struct = self.dataset.index_struct[:]
        self.dataset.index_struct = None

        if documents:
            try:
                self.create(documents)
            except Exception as e:
                self.dataset.index_struct = origin_index_struct
                raise e

            dataset.index_struct = json.dumps(self.to_index_struct())

        db.session.commit()

        self.dataset = dataset
        logging.info(f"Dataset {dataset.id} recreate successfully.")

    def create_qdrant_dataset(self, dataset: Dataset):
        logging.info(f"create_qdrant_dataset {dataset.id}")

        try:
            self.delete()
        except Exception as e:
            raise e

        dataset_documents = db.session.query(DatasetDocument).filter(
            DatasetDocument.dataset_id == dataset.id,
            DatasetDocument.indexing_status == 'completed',
            DatasetDocument.enabled == True,
            DatasetDocument.archived == False,
        ).all()

        documents = []
        for dataset_document in dataset_documents:
            segments = db.session.query(DocumentSegment).filter(
                DocumentSegment.document_id == dataset_document.id,
                DocumentSegment.status == 'completed',
                DocumentSegment.enabled == True
            ).all()

            for segment in segments:
                document = Document(
                    page_content=segment.content,
                    metadata={
                        "doc_id": segment.index_node_id,
                        "doc_hash": segment.index_node_hash,
                        "document_id": segment.document_id,
                        "dataset_id": segment.dataset_id,
                    }
                )

                documents.append(document)

        if documents:
            try:
                self.create(documents)
            except Exception as e:
                raise e

        logging.info(f"Dataset {dataset.id} recreate successfully.")

    def update_qdrant_dataset(self, dataset: Dataset):
        logging.info(f"update_qdrant_dataset {dataset.id}")

        segment = db.session.query(DocumentSegment).filter(
            DocumentSegment.dataset_id == dataset.id,
            DocumentSegment.status == 'completed',
            DocumentSegment.enabled == True
        ).first()

        if segment:
            try:
                exist = self.text_exists(segment.index_node_id)
                if exist:
                    index_struct = {
                        "type": 'qdrant',
                        "vector_store": {"class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']}
                    }
                    dataset.index_struct = json.dumps(index_struct)
                    db.session.commit()
            except Exception as e:
                raise e

        logging.info(f"Dataset {dataset.id} recreate successfully.")

    def restore_dataset_in_one(self, dataset: Dataset, dataset_collection_binding: DatasetCollectionBinding):
        logging.info(f"restore dataset in_one,_dataset {dataset.id}")

        dataset_documents = db.session.query(DatasetDocument).filter(
            DatasetDocument.dataset_id == dataset.id,
            DatasetDocument.indexing_status == 'completed',
            DatasetDocument.enabled == True,
            DatasetDocument.archived == False,
        ).all()

        documents = []
        for dataset_document in dataset_documents:
            segments = db.session.query(DocumentSegment).filter(
                DocumentSegment.document_id == dataset_document.id,
                DocumentSegment.status == 'completed',
                DocumentSegment.enabled == True
            ).all()

            for segment in segments:
                document = Document(
                    page_content=segment.content,
                    metadata={
                        "doc_id": segment.index_node_id,
                        "doc_hash": segment.index_node_hash,
                        "document_id": segment.document_id,
                        "dataset_id": segment.dataset_id,
                    }
                )

                documents.append(document)

        if documents:
            try:
                self.add_texts(documents)
            except Exception as e:
                raise e

        logging.info(f"Dataset {dataset.id} recreate successfully.")

    def delete_original_collection(self, dataset: Dataset, dataset_collection_binding: DatasetCollectionBinding):
        logging.info(f"delete original collection: {dataset.id}")

        self.delete()

        dataset.collection_binding_id = dataset_collection_binding.id
        db.session.add(dataset)
        db.session.commit()

        logging.info(f"Dataset {dataset.id} recreate successfully.")
from typing import Any, cast

from langchain.embeddings.base import Embeddings
from langchain.schema import Document
from langchain.vectorstores import VectorStore
from pydantic import BaseModel, root_validator

from core.index.base import BaseIndex
from core.index.vector_index.base import BaseVectorIndex
from core.vector_store.milvus_vector_store import MilvusVectorStore
from models.dataset import Dataset


class MilvusConfig(BaseModel):
    host: str
    port: int
    user: str
    password: str
    secure: bool = False
    batch_size: int = 100

    @root_validator()
    def validate_config(cls, values: dict) -> dict:
        if not values['host']:
            raise ValueError("config MILVUS_HOST is required")
        if not values['port']:
            raise ValueError("config MILVUS_PORT is required")
        if not values['user']:
            raise ValueError("config MILVUS_USER is required")
        if not values['password']:
            raise ValueError("config MILVUS_PASSWORD is required")
        return values

    def to_milvus_params(self):
        return {
            'host': self.host,
            'port': self.port,
            'user': self.user,
            'password': self.password,
            'secure': self.secure
        }


class MilvusVectorIndex(BaseVectorIndex):
    def __init__(self, dataset: Dataset, config: MilvusConfig, embeddings: Embeddings):
        super().__init__(dataset, embeddings)
        self._client_config = config

    def get_type(self) -> str:
        return 'milvus'

    def get_index_name(self, dataset: Dataset) -> str:
        if self.dataset.index_struct_dict:
            class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
            if not class_prefix.endswith('_Node'):
                # original class_prefix
                class_prefix += '_Node'

            return class_prefix

        dataset_id = dataset.id
        return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'

    def to_index_struct(self) -> dict:
        return {
            "type": self.get_type(),
            "vector_store": {"class_prefix": self.get_index_name(self.dataset)}
        }

    def create(self, texts: list[Document], **kwargs) -> BaseIndex:
        uuids = self._get_uuids(texts)
        index_params = {
            'metric_type': 'IP',
            'index_type': "HNSW",
            'params': {"M": 8, "efConstruction": 64}
        }
        self._vector_store = MilvusVectorStore.from_documents(
            texts,
            self._embeddings,
            collection_name=self.get_index_name(self.dataset),
            connection_args=self._client_config.to_milvus_params(),
            index_params=index_params
        )

        return self

    def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
        uuids = self._get_uuids(texts)
        self._vector_store = MilvusVectorStore.from_documents(
            texts,
            self._embeddings,
            collection_name=collection_name,
            ids=uuids,
            content_payload_key='page_content'
        )

        return self

    def _get_vector_store(self) -> VectorStore:
        """Only for created index."""
        if self._vector_store:
            return self._vector_store

        return MilvusVectorStore(
            collection_name=self.get_index_name(self.dataset),
            embedding_function=self._embeddings,
            connection_args=self._client_config.to_milvus_params()
        )

    def _get_vector_store_class(self) -> type:
        return MilvusVectorStore

    def delete_by_document_id(self, document_id: str):
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)
        ids = vector_store.get_ids_by_document_id(document_id)
        if ids:
            vector_store.del_texts({
                'filter': f'id in {ids}'
            })

    def delete_by_metadata_field(self, key: str, value: str):
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)
        ids = vector_store.get_ids_by_metadata_field(key, value)
        if ids:
            vector_store.del_texts({
                'filter': f'id in {ids}'
            })

    def delete_by_ids(self, doc_ids: list[str]) -> None:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)
        ids = vector_store.get_ids_by_doc_ids(doc_ids)
        vector_store.del_texts({
            'filter': f' id in {ids}'
        })

    def delete_by_group_id(self, group_id: str) -> None:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        vector_store.delete()

    def delete(self) -> None:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        from qdrant_client.http import models
        vector_store.del_texts(models.Filter(
            must=[
                models.FieldCondition(
                    key="group_id",
                    match=models.MatchValue(value=self.dataset.id),
                ),
            ],
        ))

    def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]:
        # milvus/zilliz doesn't support bm25 search
        return []
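# Construction sketch for MilvusVectorIndex; connection values are placeholders.
#
#   index = MilvusVectorIndex(
#       dataset=dataset,
#       config=MilvusConfig(host='127.0.0.1', port=19530, user='root',
#                           password='<password>', secure=False),
#       embeddings=embeddings
#   )
#   index.create(documents)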
import os
from typing import Any, Optional, cast

import qdrant_client
from langchain.embeddings.base import Embeddings
from langchain.schema import Document
from langchain.vectorstores import VectorStore
from pydantic import BaseModel
from qdrant_client.http.models import HnswConfigDiff

from core.index.base import BaseIndex
from core.index.vector_index.base import BaseVectorIndex
from core.vector_store.qdrant_vector_store import QdrantVectorStore
from extensions.ext_database import db
from models.dataset import Dataset, DatasetCollectionBinding


class QdrantConfig(BaseModel):
    endpoint: str
    api_key: Optional[str]
    timeout: float = 20
    root_path: Optional[str]

    def to_qdrant_params(self):
        if self.endpoint and self.endpoint.startswith('path:'):
            path = self.endpoint.replace('path:', '')
            if not os.path.isabs(path):
                path = os.path.join(self.root_path, path)

            return {
                'path': path
            }
        else:
            return {
                'url': self.endpoint,
                'api_key': self.api_key,
                'timeout': self.timeout
            }


class QdrantVectorIndex(BaseVectorIndex):
    def __init__(self, dataset: Dataset, config: QdrantConfig, embeddings: Embeddings):
        super().__init__(dataset, embeddings)
        self._client_config = config

    def get_type(self) -> str:
        return 'qdrant'

    def get_index_name(self, dataset: Dataset) -> str:
        if dataset.collection_binding_id:
            dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
                filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \
                one_or_none()
            if dataset_collection_binding:
                return dataset_collection_binding.collection_name
            else:
                raise ValueError('Dataset Collection Bindings is not exist!')
        else:
            if self.dataset.index_struct_dict:
                class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
                return class_prefix

            dataset_id = dataset.id
            return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'

    def to_index_struct(self) -> dict:
        return {
            "type": self.get_type(),
            "vector_store": {"class_prefix": self.get_index_name(self.dataset)}
        }

    def create(self, texts: list[Document], **kwargs) -> BaseIndex:
        uuids = self._get_uuids(texts)
        self._vector_store = QdrantVectorStore.from_documents(
            texts,
            self._embeddings,
            collection_name=self.get_index_name(self.dataset),
            ids=uuids,
            content_payload_key='page_content',
            group_id=self.dataset.id,
            group_payload_key='group_id',
            hnsw_config=HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000,
                                       max_indexing_threads=0, on_disk=False),
            **self._client_config.to_qdrant_params()
        )

        return self

    def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
        uuids = self._get_uuids(texts)
        self._vector_store = QdrantVectorStore.from_documents(
            texts,
            self._embeddings,
            collection_name=collection_name,
            ids=uuids,
            content_payload_key='page_content',
            group_id=self.dataset.id,
            group_payload_key='group_id',
            hnsw_config=HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000,
                                       max_indexing_threads=0, on_disk=False),
            **self._client_config.to_qdrant_params()
        )

        return self

    def _get_vector_store(self) -> VectorStore:
        """Only for created index."""
        if self._vector_store:
            return self._vector_store

        attributes = ['doc_id', 'dataset_id', 'document_id']
        client = qdrant_client.QdrantClient(
            **self._client_config.to_qdrant_params()
        )

        return QdrantVectorStore(
            client=client,
            collection_name=self.get_index_name(self.dataset),
            embeddings=self._embeddings,
            content_payload_key='page_content',
            group_id=self.dataset.id,
            group_payload_key='group_id'
        )

    def _get_vector_store_class(self) -> type:
        return QdrantVectorStore

    def delete_by_document_id(self, document_id: str):
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        from qdrant_client.http import models
        vector_store.del_texts(models.Filter(
            must=[
                models.FieldCondition(
                    key="metadata.document_id",
                    match=models.MatchValue(value=document_id),
                ),
            ],
        ))

    def delete_by_metadata_field(self, key: str, value: str):
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        from qdrant_client.http import models
        vector_store.del_texts(models.Filter(
            must=[
                models.FieldCondition(
                    key=f"metadata.{key}",
                    match=models.MatchValue(value=value),
                ),
            ],
        ))

    def delete_by_ids(self, ids: list[str]) -> None:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        from qdrant_client.http import models
        for node_id in ids:
            vector_store.del_texts(models.Filter(
                must=[
                    models.FieldCondition(
                        key="metadata.doc_id",
                        match=models.MatchValue(value=node_id),
                    ),
                ],
            ))

    def delete_by_group_id(self, group_id: str) -> None:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        from qdrant_client.http import models
        vector_store.del_texts(models.Filter(
            must=[
                models.FieldCondition(
                    key="group_id",
                    match=models.MatchValue(value=group_id),
                ),
            ],
        ))

    def delete(self) -> None:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        from qdrant_client.http import models
        vector_store.del_texts(models.Filter(
            must=[
                models.FieldCondition(
                    key="group_id",
                    match=models.MatchValue(value=self.dataset.id),
                ),
            ],
        ))

    def _is_origin(self):
        if self.dataset.index_struct_dict:
            class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
            if not class_prefix.endswith('_Node'):
                # original class_prefix
                return True

        return False

    def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        from qdrant_client.http import models
        return vector_store.similarity_search_by_bm25(models.Filter(
            must=[
                models.FieldCondition(
                    key="group_id",
                    match=models.MatchValue(value=self.dataset.id),
                ),
                models.FieldCondition(
                    key="page_content",
                    match=models.MatchText(text=query),
                )
            ],
        ), kwargs.get('top_k', 2))
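# QdrantConfig sketch: an endpoint starting with 'path:' selects a local on-disk Qdrant
# (resolved against root_path when relative); anything else is treated as a server URL.
#
#   QdrantConfig(endpoint='path:storage/qdrant', api_key=None, root_path='/app').to_qdrant_params()
#   # -> {'path': '/app/storage/qdrant'}
#   QdrantConfig(endpoint='https://qdrant.example:6333', api_key='<key>', root_path='/app').to_qdrant_params()
#   # -> {'url': 'https://qdrant.example:6333', 'api_key': '<key>', 'timeout': 20.0}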
import json

from flask import current_app
from langchain.embeddings.base import Embeddings

from core.index.vector_index.base import BaseVectorIndex
from extensions.ext_database import db
from models.dataset import Dataset, Document


class VectorIndex:
    def __init__(self, dataset: Dataset, config: dict, embeddings: Embeddings,
                 attributes: list = None):
        if attributes is None:
            attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash']
        self._dataset = dataset
        self._embeddings = embeddings
        self._vector_index = self._init_vector_index(dataset, config, embeddings, attributes)
        self._attributes = attributes

    def _init_vector_index(self, dataset: Dataset, config: dict, embeddings: Embeddings,
                           attributes: list) -> BaseVectorIndex:
        vector_type = config.get('VECTOR_STORE')

        if self._dataset.index_struct_dict:
            vector_type = self._dataset.index_struct_dict['type']

        if not vector_type:
            raise ValueError("Vector store must be specified.")

        if vector_type == "weaviate":
            from core.index.vector_index.weaviate_vector_index import WeaviateConfig, WeaviateVectorIndex

            return WeaviateVectorIndex(
                dataset=dataset,
                config=WeaviateConfig(
                    endpoint=config.get('WEAVIATE_ENDPOINT'),
                    api_key=config.get('WEAVIATE_API_KEY'),
                    batch_size=int(config.get('WEAVIATE_BATCH_SIZE'))
                ),
                embeddings=embeddings,
                attributes=attributes
            )
        elif vector_type == "qdrant":
            from core.index.vector_index.qdrant_vector_index import QdrantConfig, QdrantVectorIndex

            return QdrantVectorIndex(
                dataset=dataset,
                config=QdrantConfig(
                    endpoint=config.get('QDRANT_URL'),
                    api_key=config.get('QDRANT_API_KEY'),
                    root_path=current_app.root_path,
                    timeout=config.get('QDRANT_CLIENT_TIMEOUT')
                ),
                embeddings=embeddings
            )
        elif vector_type == "milvus":
            from core.index.vector_index.milvus_vector_index import MilvusConfig, MilvusVectorIndex

            return MilvusVectorIndex(
                dataset=dataset,
                config=MilvusConfig(
                    host=config.get('MILVUS_HOST'),
                    port=config.get('MILVUS_PORT'),
                    user=config.get('MILVUS_USER'),
                    password=config.get('MILVUS_PASSWORD'),
                    secure=config.get('MILVUS_SECURE'),
                ),
                embeddings=embeddings
            )
        else:
            raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")

    def add_texts(self, texts: list[Document], **kwargs):
        if not self._dataset.index_struct_dict:
            self._vector_index.create(texts, **kwargs)
            self._dataset.index_struct = json.dumps(self._vector_index.to_index_struct())
            db.session.commit()
            return

        self._vector_index.add_texts(texts, **kwargs)

    def __getattr__(self, name):
        if self._vector_index is not None:
            method = getattr(self._vector_index, name)
            if callable(method):
                return method

        raise AttributeError(f"'VectorIndex' object has no attribute '{name}'")
| from typing import Any, Optional, cast | |||||
| import requests | |||||
| import weaviate | |||||
| from langchain.embeddings.base import Embeddings | |||||
| from langchain.schema import Document | |||||
| from langchain.vectorstores import VectorStore | |||||
| from pydantic import BaseModel, root_validator | |||||
| from core.index.base import BaseIndex | |||||
| from core.index.vector_index.base import BaseVectorIndex | |||||
| from core.vector_store.weaviate_vector_store import WeaviateVectorStore | |||||
| from models.dataset import Dataset | |||||
| class WeaviateConfig(BaseModel): | |||||
| endpoint: str | |||||
| api_key: Optional[str] | |||||
| batch_size: int = 100 | |||||
| @root_validator() | |||||
| def validate_config(cls, values: dict) -> dict: | |||||
| if not values['endpoint']: | |||||
| raise ValueError("config WEAVIATE_ENDPOINT is required") | |||||
| return values | |||||
| class WeaviateVectorIndex(BaseVectorIndex): | |||||
| def __init__(self, dataset: Dataset, config: WeaviateConfig, embeddings: Embeddings, attributes: list): | |||||
| super().__init__(dataset, embeddings) | |||||
| self._client = self._init_client(config) | |||||
| self._attributes = attributes | |||||
| def _init_client(self, config: WeaviateConfig) -> weaviate.Client: | |||||
| auth_config = weaviate.auth.AuthApiKey(api_key=config.api_key) | |||||
| weaviate.connect.connection.has_grpc = False | |||||
| try: | |||||
| client = weaviate.Client( | |||||
| url=config.endpoint, | |||||
| auth_client_secret=auth_config, | |||||
| timeout_config=(5, 60), | |||||
| startup_period=None | |||||
| ) | |||||
| except requests.exceptions.ConnectionError: | |||||
| raise ConnectionError("Vector database connection error") | |||||
| client.batch.configure( | |||||
| # `batch_size` takes an `int` value to enable auto-batching | |||||
| # (`None` is used for manual batching) | |||||
| batch_size=config.batch_size, | |||||
| # dynamically update the `batch_size` based on import speed | |||||
| dynamic=True, | |||||
| # `timeout_retries` takes an `int` value to retry on time outs | |||||
| timeout_retries=3, | |||||
| ) | |||||
| return client | |||||
| def get_type(self) -> str: | |||||
| return 'weaviate' | |||||
| def get_index_name(self, dataset: Dataset) -> str: | |||||
| if self.dataset.index_struct_dict: | |||||
| class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix'] | |||||
| if not class_prefix.endswith('_Node'): | |||||
| # original class_prefix | |||||
| class_prefix += '_Node' | |||||
| return class_prefix | |||||
| dataset_id = dataset.id | |||||
| return "Vector_index_" + dataset_id.replace("-", "_") + '_Node' | |||||
| def to_index_struct(self) -> dict: | |||||
| return { | |||||
| "type": self.get_type(), | |||||
| "vector_store": {"class_prefix": self.get_index_name(self.dataset)} | |||||
| } | |||||
| def create(self, texts: list[Document], **kwargs) -> BaseIndex: | |||||
| uuids = self._get_uuids(texts) | |||||
| self._vector_store = WeaviateVectorStore.from_documents( | |||||
| texts, | |||||
| self._embeddings, | |||||
| client=self._client, | |||||
| index_name=self.get_index_name(self.dataset), | |||||
| uuids=uuids, | |||||
| by_text=False | |||||
| ) | |||||
| return self | |||||
| def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex: | |||||
| uuids = self._get_uuids(texts) | |||||
| self._vector_store = WeaviateVectorStore.from_documents( | |||||
| texts, | |||||
| self._embeddings, | |||||
| client=self._client, | |||||
| index_name=self.get_index_name(self.dataset), | |||||
| uuids=uuids, | |||||
| by_text=False | |||||
| ) | |||||
| return self | |||||
| def _get_vector_store(self) -> VectorStore: | |||||
| """Only for created index.""" | |||||
| if self._vector_store: | |||||
| return self._vector_store | |||||
| attributes = self._attributes | |||||
| if self._is_origin(): | |||||
| attributes = ['doc_id'] | |||||
| return WeaviateVectorStore( | |||||
| client=self._client, | |||||
| index_name=self.get_index_name(self.dataset), | |||||
| text_key='text', | |||||
| embedding=self._embeddings, | |||||
| attributes=attributes, | |||||
| by_text=False | |||||
| ) | |||||
| def _get_vector_store_class(self) -> type: | |||||
| return WeaviateVectorStore | |||||
| def delete_by_document_id(self, document_id: str): | |||||
| if self._is_origin(): | |||||
| self.recreate_dataset(self.dataset) | |||||
| return | |||||
| vector_store = self._get_vector_store() | |||||
| vector_store = cast(self._get_vector_store_class(), vector_store) | |||||
| vector_store.del_texts({ | |||||
| "operator": "Equal", | |||||
| "path": ["document_id"], | |||||
| "valueText": document_id | |||||
| }) | |||||
| def delete_by_metadata_field(self, key: str, value: str): | |||||
| vector_store = self._get_vector_store() | |||||
| vector_store = cast(self._get_vector_store_class(), vector_store) | |||||
| vector_store.del_texts({ | |||||
| "operator": "Equal", | |||||
| "path": [key], | |||||
| "valueText": value | |||||
| }) | |||||
| def delete_by_group_id(self, group_id: str): | |||||
| if self._is_origin(): | |||||
| self.recreate_dataset(self.dataset) | |||||
| return | |||||
| vector_store = self._get_vector_store() | |||||
| vector_store = cast(self._get_vector_store_class(), vector_store) | |||||
| vector_store.delete() | |||||
| def _is_origin(self): | |||||
| if self.dataset.index_struct_dict: | |||||
| class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix'] | |||||
| if not class_prefix.endswith('_Node'): | |||||
| # original class_prefix | |||||
| return True | |||||
| return False | |||||
| def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]: | |||||
| vector_store = self._get_vector_store() | |||||
| vector_store = cast(self._get_vector_store_class(), vector_store) | |||||
| return vector_store.similarity_search_by_bm25(query, kwargs.get('top_k', 2), **kwargs) | |||||
| from flask import Flask, current_app | from flask import Flask, current_app | ||||
| from flask_login import current_user | from flask_login import current_user | ||||
| from langchain.schema import Document | |||||
| from langchain.text_splitter import TextSplitter | from langchain.text_splitter import TextSplitter | ||||
| from sqlalchemy.orm.exc import ObjectDeletedError | from sqlalchemy.orm.exc import ObjectDeletedError | ||||
| from core.data_loader.file_extractor import FileExtractor | |||||
| from core.data_loader.loader.notion import NotionLoader | |||||
| from core.docstore.dataset_docstore import DatasetDocumentStore | from core.docstore.dataset_docstore import DatasetDocumentStore | ||||
| from core.errors.error import ProviderTokenNotInitError | from core.errors.error import ProviderTokenNotInitError | ||||
| from core.generator.llm_generator import LLMGenerator | from core.generator.llm_generator import LLMGenerator | ||||
| from core.index.index import IndexBuilder | |||||
| from core.model_manager import ModelInstance, ModelManager | from core.model_manager import ModelInstance, ModelManager | ||||
| from core.model_runtime.entities.model_entities import ModelType, PriceType | from core.model_runtime.entities.model_entities import ModelType, PriceType | ||||
| from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel | from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel | ||||
| from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel | from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel | ||||
| from core.rag.extractor.entity.extract_setting import ExtractSetting | |||||
| from core.rag.index_processor.index_processor_base import BaseIndexProcessor | |||||
| from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | |||||
| from core.rag.models.document import Document | |||||
| from core.splitter.fixed_text_splitter import EnhanceRecursiveCharacterTextSplitter, FixedRecursiveCharacterTextSplitter | from core.splitter.fixed_text_splitter import EnhanceRecursiveCharacterTextSplitter, FixedRecursiveCharacterTextSplitter | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from extensions.ext_redis import redis_client | from extensions.ext_redis import redis_client | ||||
| from models.dataset import Dataset, DatasetProcessRule, DocumentSegment | from models.dataset import Dataset, DatasetProcessRule, DocumentSegment | ||||
| from models.dataset import Document as DatasetDocument | from models.dataset import Document as DatasetDocument | ||||
| from models.model import UploadFile | from models.model import UploadFile | ||||
| from models.source import DataSourceBinding | |||||
| from services.feature_service import FeatureService | from services.feature_service import FeatureService | ||||
| processing_rule = db.session.query(DatasetProcessRule). \ | processing_rule = db.session.query(DatasetProcessRule). \ | ||||
| filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \ | filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \ | ||||
| first() | first() | ||||
| # load file | |||||
| text_docs = self._load_data(dataset_document, processing_rule.mode == 'automatic') | |||||
| # get embedding model instance | |||||
| embedding_model_instance = None | |||||
| if dataset.indexing_technique == 'high_quality': | |||||
| if dataset.embedding_model_provider: | |||||
| embedding_model_instance = self.model_manager.get_model_instance( | |||||
| tenant_id=dataset.tenant_id, | |||||
| provider=dataset.embedding_model_provider, | |||||
| model_type=ModelType.TEXT_EMBEDDING, | |||||
| model=dataset.embedding_model | |||||
| ) | |||||
| else: | |||||
| embedding_model_instance = self.model_manager.get_default_model_instance( | |||||
| tenant_id=dataset.tenant_id, | |||||
| model_type=ModelType.TEXT_EMBEDDING, | |||||
| ) | |||||
| # get splitter | |||||
| splitter = self._get_splitter(processing_rule, embedding_model_instance) | |||||
| # split to documents | |||||
| documents = self._step_split( | |||||
| text_docs=text_docs, | |||||
| splitter=splitter, | |||||
| dataset=dataset, | |||||
| dataset_document=dataset_document, | |||||
| processing_rule=processing_rule | |||||
| ) | |||||
| self._build_index( | |||||
| index_type = dataset_document.doc_form | |||||
| index_processor = IndexProcessorFactory(index_type).init_index_processor() | |||||
| # extract | |||||
| text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict()) | |||||
| # transform | |||||
| documents = self._transform(index_processor, dataset, text_docs, processing_rule.to_dict()) | |||||
| # save segment | |||||
| self._load_segments(dataset, dataset_document, documents) | |||||
| # load | |||||
| self._load( | |||||
| index_processor=index_processor, | |||||
| dataset=dataset, | dataset=dataset, | ||||
| dataset_document=dataset_document, | dataset_document=dataset_document, | ||||
| documents=documents | documents=documents | ||||
| filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \ | filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \ | ||||
| first() | first() | ||||
| # load file | |||||
| text_docs = self._load_data(dataset_document, processing_rule.mode == 'automatic') | |||||
| # get embedding model instance | |||||
| embedding_model_instance = None | |||||
| if dataset.indexing_technique == 'high_quality': | |||||
| if dataset.embedding_model_provider: | |||||
| embedding_model_instance = self.model_manager.get_model_instance( | |||||
| tenant_id=dataset.tenant_id, | |||||
| provider=dataset.embedding_model_provider, | |||||
| model_type=ModelType.TEXT_EMBEDDING, | |||||
| model=dataset.embedding_model | |||||
| ) | |||||
| else: | |||||
| embedding_model_instance = self.model_manager.get_default_model_instance( | |||||
| tenant_id=dataset.tenant_id, | |||||
| model_type=ModelType.TEXT_EMBEDDING, | |||||
| ) | |||||
| # get splitter | |||||
| splitter = self._get_splitter(processing_rule, embedding_model_instance) | |||||
| index_type = dataset_document.doc_form | |||||
| index_processor = IndexProcessorFactory(index_type).init_index_processor() | |||||
| # extract | |||||
| text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict()) | |||||
| # split to documents | |||||
| documents = self._step_split( | |||||
| text_docs=text_docs, | |||||
| splitter=splitter, | |||||
| dataset=dataset, | |||||
| dataset_document=dataset_document, | |||||
| processing_rule=processing_rule | |||||
| ) | |||||
| # transform | |||||
| documents = self._transform(index_processor, dataset, text_docs, processing_rule.to_dict()) | |||||
| # save segment | |||||
| self._load_segments(dataset, dataset_document, documents) | |||||
| # build index | |||||
| self._build_index( | |||||
| # load | |||||
| self._load( | |||||
| index_processor=index_processor, | |||||
| dataset=dataset, | dataset=dataset, | ||||
| dataset_document=dataset_document, | dataset_document=dataset_document, | ||||
| documents=documents | documents=documents | ||||
| documents.append(document) | documents.append(document) | ||||
| # build index | # build index | ||||
| self._build_index( | |||||
| # get the process rule | |||||
| processing_rule = db.session.query(DatasetProcessRule). \ | |||||
| filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \ | |||||
| first() | |||||
| index_type = dataset_document.doc_form | |||||
| index_processor = IndexProcessorFactory(index_type, processing_rule.to_dict()).init_index_processor() | |||||
| self._load( | |||||
| index_processor=index_processor, | |||||
| dataset=dataset, | dataset=dataset, | ||||
| dataset_document=dataset_document, | dataset_document=dataset_document, | ||||
| documents=documents | documents=documents | ||||
| dataset_document.stopped_at = datetime.datetime.utcnow() | dataset_document.stopped_at = datetime.datetime.utcnow() | ||||
| db.session.commit() | db.session.commit() | ||||
| def file_indexing_estimate(self, tenant_id: str, file_details: list[UploadFile], tmp_processing_rule: dict, | |||||
| doc_form: str = None, doc_language: str = 'English', dataset_id: str = None, | |||||
| indexing_technique: str = 'economy') -> dict: | |||||
| def indexing_estimate(self, tenant_id: str, extract_settings: list[ExtractSetting], tmp_processing_rule: dict, | |||||
| doc_form: str = None, doc_language: str = 'English', dataset_id: str = None, | |||||
| indexing_technique: str = 'economy') -> dict: | |||||
| """ | """ | ||||
| Estimate the indexing for the document. | Estimate the indexing for the document. | ||||
| """ | """ | ||||
| # check document limit | # check document limit | ||||
| features = FeatureService.get_features(tenant_id) | features = FeatureService.get_features(tenant_id) | ||||
| if features.billing.enabled: | if features.billing.enabled: | ||||
| count = len(file_details) | |||||
| count = len(extract_settings) | |||||
| batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT']) | batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT']) | ||||
| if count > batch_upload_limit: | if count > batch_upload_limit: | ||||
| raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") | raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") | ||||
| total_segments = 0 | total_segments = 0 | ||||
| total_price = 0 | total_price = 0 | ||||
| currency = 'USD' | currency = 'USD' | ||||
| for file_detail in file_details: | |||||
| index_type = doc_form | |||||
| index_processor = IndexProcessorFactory(index_type).init_index_processor() | |||||
| all_text_docs = [] | |||||
| for extract_setting in extract_settings: | |||||
| # extract | |||||
| text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"]) | |||||
| all_text_docs.extend(text_docs) | |||||
| processing_rule = DatasetProcessRule( | processing_rule = DatasetProcessRule( | ||||
| mode=tmp_processing_rule["mode"], | mode=tmp_processing_rule["mode"], | ||||
| rules=json.dumps(tmp_processing_rule["rules"]) | rules=json.dumps(tmp_processing_rule["rules"]) | ||||
| ) | ) | ||||
| # load data from file | |||||
| text_docs = FileExtractor.load(file_detail, is_automatic=processing_rule.mode == 'automatic') | |||||
| # get splitter | # get splitter | ||||
| splitter = self._get_splitter(processing_rule, embedding_model_instance) | splitter = self._get_splitter(processing_rule, embedding_model_instance) | ||||
| ) | ) | ||||
| total_segments += len(documents) | total_segments += len(documents) | ||||
| for document in documents: | for document in documents: | ||||
| if len(preview_texts) < 5: | if len(preview_texts) < 5: | ||||
| preview_texts.append(document.page_content) | preview_texts.append(document.page_content) | ||||
| "preview": preview_texts | "preview": preview_texts | ||||
| } | } | ||||
| def notion_indexing_estimate(self, tenant_id: str, notion_info_list: list, tmp_processing_rule: dict, | |||||
| doc_form: str = None, doc_language: str = 'English', dataset_id: str = None, | |||||
| indexing_technique: str = 'economy') -> dict: | |||||
| """ | |||||
| Estimate the indexing for the document. | |||||
| """ | |||||
| # check document limit | |||||
| features = FeatureService.get_features(tenant_id) | |||||
| if features.billing.enabled: | |||||
| count = len(notion_info_list) | |||||
| batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT']) | |||||
| if count > batch_upload_limit: | |||||
| raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") | |||||
| embedding_model_instance = None | |||||
| if dataset_id: | |||||
| dataset = Dataset.query.filter_by( | |||||
| id=dataset_id | |||||
| ).first() | |||||
| if not dataset: | |||||
| raise ValueError('Dataset not found.') | |||||
| if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality': | |||||
| if dataset.embedding_model_provider: | |||||
| embedding_model_instance = self.model_manager.get_model_instance( | |||||
| tenant_id=tenant_id, | |||||
| provider=dataset.embedding_model_provider, | |||||
| model_type=ModelType.TEXT_EMBEDDING, | |||||
| model=dataset.embedding_model | |||||
| ) | |||||
| else: | |||||
| embedding_model_instance = self.model_manager.get_default_model_instance( | |||||
| tenant_id=tenant_id, | |||||
| model_type=ModelType.TEXT_EMBEDDING, | |||||
| ) | |||||
| else: | |||||
| if indexing_technique == 'high_quality': | |||||
| embedding_model_instance = self.model_manager.get_default_model_instance( | |||||
| tenant_id=tenant_id, | |||||
| model_type=ModelType.TEXT_EMBEDDING | |||||
| ) | |||||
| # load data from notion | |||||
| tokens = 0 | |||||
| preview_texts = [] | |||||
| total_segments = 0 | |||||
| total_price = 0 | |||||
| currency = 'USD' | |||||
| for notion_info in notion_info_list: | |||||
| workspace_id = notion_info['workspace_id'] | |||||
| data_source_binding = DataSourceBinding.query.filter( | |||||
| db.and_( | |||||
| DataSourceBinding.tenant_id == current_user.current_tenant_id, | |||||
| DataSourceBinding.provider == 'notion', | |||||
| DataSourceBinding.disabled == False, | |||||
| DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"' | |||||
| ) | |||||
| ).first() | |||||
| if not data_source_binding: | |||||
| raise ValueError('Data source binding not found.') | |||||
| for page in notion_info['pages']: | |||||
| loader = NotionLoader( | |||||
| notion_access_token=data_source_binding.access_token, | |||||
| notion_workspace_id=workspace_id, | |||||
| notion_obj_id=page['page_id'], | |||||
| notion_page_type=page['type'] | |||||
| ) | |||||
| documents = loader.load() | |||||
| processing_rule = DatasetProcessRule( | |||||
| mode=tmp_processing_rule["mode"], | |||||
| rules=json.dumps(tmp_processing_rule["rules"]) | |||||
| ) | |||||
| # get splitter | |||||
| splitter = self._get_splitter(processing_rule, embedding_model_instance) | |||||
| # split to documents | |||||
| documents = self._split_to_documents_for_estimate( | |||||
| text_docs=documents, | |||||
| splitter=splitter, | |||||
| processing_rule=processing_rule | |||||
| ) | |||||
| total_segments += len(documents) | |||||
| embedding_model_type_instance = None | |||||
| if embedding_model_instance: | |||||
| embedding_model_type_instance = embedding_model_instance.model_type_instance | |||||
| embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance) | |||||
| for document in documents: | |||||
| if len(preview_texts) < 5: | |||||
| preview_texts.append(document.page_content) | |||||
| if indexing_technique == 'high_quality' and embedding_model_type_instance: | |||||
| tokens += embedding_model_type_instance.get_num_tokens( | |||||
| model=embedding_model_instance.model, | |||||
| credentials=embedding_model_instance.credentials, | |||||
| texts=[document.page_content] | |||||
| ) | |||||
| if doc_form and doc_form == 'qa_model': | |||||
| model_instance = self.model_manager.get_default_model_instance( | |||||
| tenant_id=tenant_id, | |||||
| model_type=ModelType.LLM | |||||
| ) | |||||
| model_type_instance = model_instance.model_type_instance | |||||
| model_type_instance = cast(LargeLanguageModel, model_type_instance) | |||||
| if len(preview_texts) > 0: | |||||
| # qa model document | |||||
| response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0], | |||||
| doc_language) | |||||
| document_qa_list = self.format_split_text(response) | |||||
| price_info = model_type_instance.get_price( | |||||
| model=model_instance.model, | |||||
| credentials=model_instance.credentials, | |||||
| price_type=PriceType.INPUT, | |||||
| tokens=total_segments * 2000, | |||||
| ) | |||||
| return { | |||||
| "total_segments": total_segments * 20, | |||||
| "tokens": total_segments * 2000, | |||||
| "total_price": '{:f}'.format(price_info.total_amount), | |||||
| "currency": price_info.currency, | |||||
| "qa_preview": document_qa_list, | |||||
| "preview": preview_texts | |||||
| } | |||||
| if embedding_model_instance: | |||||
| embedding_model_type_instance = embedding_model_instance.model_type_instance | |||||
| embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance) | |||||
| embedding_price_info = embedding_model_type_instance.get_price( | |||||
| model=embedding_model_instance.model, | |||||
| credentials=embedding_model_instance.credentials, | |||||
| price_type=PriceType.INPUT, | |||||
| tokens=tokens | |||||
| ) | |||||
| total_price = '{:f}'.format(embedding_price_info.total_amount) | |||||
| currency = embedding_price_info.currency | |||||
| return { | |||||
| "total_segments": total_segments, | |||||
| "tokens": tokens, | |||||
| "total_price": total_price, | |||||
| "currency": currency, | |||||
| "preview": preview_texts | |||||
| } | |||||
| def _load_data(self, dataset_document: DatasetDocument, automatic: bool = False) -> list[Document]: | |||||
| def _extract(self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict) \ | |||||
| -> list[Document]: | |||||
| # load file | # load file | ||||
| if dataset_document.data_source_type not in ["upload_file", "notion_import"]: | if dataset_document.data_source_type not in ["upload_file", "notion_import"]: | ||||
| return [] | return [] | ||||
| one_or_none() | one_or_none() | ||||
| if file_detail: | if file_detail: | ||||
| text_docs = FileExtractor.load(file_detail, is_automatic=automatic) | |||||
| extract_setting = ExtractSetting( | |||||
| datasource_type="upload_file", | |||||
| upload_file=file_detail, | |||||
| document_model=dataset_document.doc_form | |||||
| ) | |||||
| text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode']) | |||||
| elif dataset_document.data_source_type == 'notion_import': | elif dataset_document.data_source_type == 'notion_import': | ||||
| loader = NotionLoader.from_document(dataset_document) | |||||
| text_docs = loader.load() | |||||
| if (not data_source_info or 'notion_workspace_id' not in data_source_info | |||||
| or 'notion_page_id' not in data_source_info): | |||||
| raise ValueError("no notion import info found") | |||||
| extract_setting = ExtractSetting( | |||||
| datasource_type="notion_import", | |||||
| notion_info={ | |||||
| "notion_workspace_id": data_source_info['notion_workspace_id'], | |||||
| "notion_obj_id": data_source_info['notion_page_id'], | |||||
| "notion_page_type": data_source_info['notion_page_type'], | |||||
| "document": dataset_document | |||||
| }, | |||||
| document_model=dataset_document.doc_form | |||||
| ) | |||||
| text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode']) | |||||
| # update document status to splitting | # update document status to splitting | ||||
| self._update_document_index_status( | self._update_document_index_status( | ||||
| document_id=dataset_document.id, | document_id=dataset_document.id, | ||||
| # replace doc id to document model id | # replace doc id to document model id | ||||
| text_docs = cast(list[Document], text_docs) | text_docs = cast(list[Document], text_docs) | ||||
| for text_doc in text_docs: | for text_doc in text_docs: | ||||
| # remove invalid symbol | |||||
| text_doc.page_content = self.filter_string(text_doc.page_content) | |||||
| text_doc.metadata['document_id'] = dataset_document.id | text_doc.metadata['document_id'] = dataset_document.id | ||||
| text_doc.metadata['dataset_id'] = dataset_document.dataset_id | text_doc.metadata['dataset_id'] = dataset_document.dataset_id | ||||
| for q, a in matches if q and a | for q, a in matches if q and a | ||||
| ] | ] | ||||
| def _build_index(self, dataset: Dataset, dataset_document: DatasetDocument, documents: list[Document]) -> None: | |||||
| def _load(self, index_processor: BaseIndexProcessor, dataset: Dataset, | |||||
| dataset_document: DatasetDocument, documents: list[Document]) -> None: | |||||
| """ | """ | ||||
| Build the index for the document. | |||||
| Insert index and update document/segment status to completed. | |||||
| """ | """ | ||||
| vector_index = IndexBuilder.get_index(dataset, 'high_quality') | |||||
| keyword_table_index = IndexBuilder.get_index(dataset, 'economy') | |||||
| embedding_model_instance = None | embedding_model_instance = None | ||||
| if dataset.indexing_technique == 'high_quality': | if dataset.indexing_technique == 'high_quality': | ||||
| embedding_model_instance = self.model_manager.get_model_instance( | embedding_model_instance = self.model_manager.get_model_instance( | ||||
| ) | ) | ||||
| for document in chunk_documents | for document in chunk_documents | ||||
| ) | ) | ||||
| # save vector index | |||||
| if vector_index: | |||||
| vector_index.add_texts(chunk_documents) | |||||
| # save keyword index | |||||
| keyword_table_index.add_texts(chunk_documents) | |||||
| # load index | |||||
| index_processor.load(dataset, chunk_documents) | |||||
| document_ids = [document.metadata['doc_id'] for document in chunk_documents] | document_ids = [document.metadata['doc_id'] for document in chunk_documents] | ||||
| db.session.query(DocumentSegment).filter( | db.session.query(DocumentSegment).filter( | ||||
| ) | ) | ||||
| documents.append(document) | documents.append(document) | ||||
| # save vector index | # save vector index | ||||
| index = IndexBuilder.get_index(dataset, 'high_quality') | |||||
| if index: | |||||
| index.add_texts(documents, duplicate_check=True) | |||||
| # save keyword index | |||||
| index = IndexBuilder.get_index(dataset, 'economy') | |||||
| if index: | |||||
| index.add_texts(documents) | |||||
| index_type = dataset.doc_form | |||||
| index_processor = IndexProcessorFactory(index_type).init_index_processor() | |||||
| index_processor.load(dataset, documents) | |||||
| def _transform(self, index_processor: BaseIndexProcessor, dataset: Dataset, | |||||
| text_docs: list[Document], process_rule: dict) -> list[Document]: | |||||
| # get embedding model instance | |||||
| embedding_model_instance = None | |||||
| if dataset.indexing_technique == 'high_quality': | |||||
| if dataset.embedding_model_provider: | |||||
| embedding_model_instance = self.model_manager.get_model_instance( | |||||
| tenant_id=dataset.tenant_id, | |||||
| provider=dataset.embedding_model_provider, | |||||
| model_type=ModelType.TEXT_EMBEDDING, | |||||
| model=dataset.embedding_model | |||||
| ) | |||||
| else: | |||||
| embedding_model_instance = self.model_manager.get_default_model_instance( | |||||
| tenant_id=dataset.tenant_id, | |||||
| model_type=ModelType.TEXT_EMBEDDING, | |||||
| ) | |||||
| documents = index_processor.transform(text_docs, embedding_model_instance=embedding_model_instance, | |||||
| process_rule=process_rule) | |||||
| return documents | |||||
| def _load_segments(self, dataset, dataset_document, documents): | |||||
| # save node to document segment | |||||
| doc_store = DatasetDocumentStore( | |||||
| dataset=dataset, | |||||
| user_id=dataset_document.created_by, | |||||
| document_id=dataset_document.id | |||||
| ) | |||||
| # add document segments | |||||
| doc_store.add_documents(documents) | |||||
| # update document status to indexing | |||||
| cur_time = datetime.datetime.utcnow() | |||||
| self._update_document_index_status( | |||||
| document_id=dataset_document.id, | |||||
| after_indexing_status="indexing", | |||||
| extra_update_params={ | |||||
| DatasetDocument.cleaning_completed_at: cur_time, | |||||
| DatasetDocument.splitting_completed_at: cur_time, | |||||
| } | |||||
| ) | |||||
| # update segment status to indexing | |||||
| self._update_segments_by_document( | |||||
| dataset_document_id=dataset_document.id, | |||||
| update_params={ | |||||
| DocumentSegment.status: "indexing", | |||||
| DocumentSegment.indexing_at: datetime.datetime.utcnow() | |||||
| } | |||||
| ) | |||||
| pass | |||||
| class DocumentIsPausedException(Exception): | class DocumentIsPausedException(Exception): |
| import re | |||||
| class CleanProcessor: | |||||
| @classmethod | |||||
| def clean(cls, text: str, process_rule: dict) -> str: | |||||
| # default clean | |||||
| # remove invalid symbol | |||||
| text = re.sub(r'<\|', '<', text) | |||||
| text = re.sub(r'\|>', '>', text) | |||||
| text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\xEF\xBF\xBE]', '', text) | |||||
| # Unicode U+FFFE | |||||
| text = re.sub('\uFFFE', '', text) | |||||
| rules = process_rule.get('rules') if process_rule else None | |||||
| if rules and 'pre_processing_rules' in rules: | |||||
| pre_processing_rules = rules["pre_processing_rules"] | |||||
| for pre_processing_rule in pre_processing_rules: | |||||
| if pre_processing_rule["id"] == "remove_extra_spaces" and pre_processing_rule["enabled"] is True: | |||||
| # Remove extra spaces | |||||
| pattern = r'\n{3,}' | |||||
| text = re.sub(pattern, '\n\n', text) | |||||
| pattern = r'[\t\f\r\x20\u00a0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]{2,}' | |||||
| text = re.sub(pattern, ' ', text) | |||||
| elif pre_processing_rule["id"] == "remove_urls_emails" and pre_processing_rule["enabled"] is True: | |||||
| # Remove email | |||||
| pattern = r'([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)' | |||||
| text = re.sub(pattern, '', text) | |||||
| # Remove URL | |||||
| pattern = r'https?://[^\s]+' | |||||
| text = re.sub(pattern, '', text) | |||||
| return text | |||||
| def filter_string(self, text): | |||||
| return text |
| """Abstract interface for document cleaner implementations.""" | |||||
| from abc import ABC, abstractmethod | |||||
| class BaseCleaner(ABC): | |||||
| """Interface for clean chunk content. | |||||
| """ | |||||
| @abstractmethod | |||||
| def clean(self, content: str): | |||||
| raise NotImplementedError | |||||
| """Abstract interface for document clean implementations.""" | |||||
| from core.rag.cleaner.cleaner_base import BaseCleaner | |||||
| class UnstructuredExtraWhitespaceCleaner(BaseCleaner): | |||||
| def clean(self, content) -> str: | |||||
| """clean document content.""" | |||||
| from unstructured.cleaners.core import clean_extra_whitespace | |||||
| # Returns "ITEM 1A: RISK FACTORS" | |||||
| return clean_extra_whitespace(content) |
| """Abstract interface for document clean implementations.""" | |||||
| from core.rag.cleaner.cleaner_base import BaseCleaner | |||||
| class UnstructuredGroupBrokenParagraphsCleaner(BaseCleaner): | |||||
| def clean(self, content) -> str: | |||||
| """clean document content.""" | |||||
| import re | |||||
| from unstructured.cleaners.core import group_broken_paragraphs | |||||
| para_split_re = re.compile(r"(\s*\n\s*){3}") | |||||
| return group_broken_paragraphs(content, paragraph_split=para_split_re) |
| """Abstract interface for document clean implementations.""" | |||||
| from core.rag.cleaner.cleaner_base import BaseCleaner | |||||
| class UnstructuredNonAsciiCharsCleaner(BaseCleaner): | |||||
| def clean(self, content) -> str: | |||||
| """clean document content.""" | |||||
| from unstructured.cleaners.core import clean_non_ascii_chars | |||||
| # Returns "This text containsnon-ascii characters!" | |||||
| return clean_non_ascii_chars(content) |
| """Abstract interface for document clean implementations.""" | |||||
| from core.rag.cleaner.cleaner_base import BaseCleaner | |||||
| class UnstructuredReplaceUnicodeQuotesCleaner(BaseCleaner): | |||||
| def clean(self, content) -> str: | |||||
| """Replaces unicode quote characters, such as the \x91 character in a string.""" | |||||
| from unstructured.cleaners.core import replace_unicode_quotes | |||||
| return replace_unicode_quotes(content) |
| """Abstract interface for document clean implementations.""" | |||||
| from core.rag.cleaner.cleaner_base import BaseCleaner | |||||
| class UnstructuredTranslateTextCleaner(BaseCleaner): | |||||
| def clean(self, content) -> str: | |||||
| """clean document content.""" | |||||
| from unstructured.cleaners.translate import translate_text | |||||
| return translate_text(content) |
| from typing import Optional | |||||
| from core.model_manager import ModelManager | |||||
| from core.model_runtime.entities.model_entities import ModelType | |||||
| from core.model_runtime.errors.invoke import InvokeAuthorizationError | |||||
| from core.rag.data_post_processor.reorder import ReorderRunner | |||||
| from core.rag.models.document import Document | |||||
| from core.rerank.rerank import RerankRunner | |||||
| class DataPostProcessor: | |||||
| """Interface for data post-processing document. | |||||
| """ | |||||
| def __init__(self, tenant_id: str, reranking_model: dict, reorder_enabled: bool = False): | |||||
| self.rerank_runner = self._get_rerank_runner(reranking_model, tenant_id) | |||||
| self.reorder_runner = self._get_reorder_runner(reorder_enabled) | |||||
| def invoke(self, query: str, documents: list[Document], score_threshold: Optional[float] = None, | |||||
| top_n: Optional[int] = None, user: Optional[str] = None) -> list[Document]: | |||||
| if self.rerank_runner: | |||||
| documents = self.rerank_runner.run(query, documents, score_threshold, top_n, user) | |||||
| if self.reorder_runner: | |||||
| documents = self.reorder_runner.run(documents) | |||||
| return documents | |||||
| def _get_rerank_runner(self, reranking_model: dict, tenant_id: str) -> Optional[RerankRunner]: | |||||
| if reranking_model: | |||||
| try: | |||||
| model_manager = ModelManager() | |||||
| rerank_model_instance = model_manager.get_model_instance( | |||||
| tenant_id=tenant_id, | |||||
| provider=reranking_model['reranking_provider_name'], | |||||
| model_type=ModelType.RERANK, | |||||
| model=reranking_model['reranking_model_name'] | |||||
| ) | |||||
| except InvokeAuthorizationError: | |||||
| return None | |||||
| return RerankRunner(rerank_model_instance) | |||||
| return None | |||||
| def _get_reorder_runner(self, reorder_enabled) -> Optional[ReorderRunner]: | |||||
| if reorder_enabled: | |||||
| return ReorderRunner() | |||||
| return None | |||||
| from langchain.schema import Document | |||||
| class ReorderRunner: | |||||
| def run(self, documents: list[Document]) -> list[Document]: | |||||
| # Elements at even indices (0, 2, 4, ...) of the documents list | |||||
| even_index_elements = documents[::2] | |||||
| # Elements at odd indices (1, 3, 5, ...) of the documents list | |||||
| odd_index_elements = documents[1::2] | |||||
| # Reverse the odd-index elements | |||||
| odd_index_elements_reversed = odd_index_elements[::-1] | |||||
| new_documents = even_index_elements + odd_index_elements_reversed | |||||
| return new_documents |
| from abc import ABC, abstractmethod | |||||
| class Embeddings(ABC): | |||||
| """Interface for embedding models.""" | |||||
| @abstractmethod | |||||
| def embed_documents(self, texts: list[str]) -> list[list[float]]: | |||||
| """Embed search docs.""" | |||||
| @abstractmethod | |||||
| def embed_query(self, text: str) -> list[float]: | |||||
| """Embed query text.""" | |||||
| async def aembed_documents(self, texts: list[str]) -> list[list[float]]: | |||||
| """Asynchronous Embed search docs.""" | |||||
| raise NotImplementedError | |||||
| async def aembed_query(self, text: str) -> list[float]: | |||||
| """Asynchronous Embed query text.""" | |||||
| raise NotImplementedError |
| from collections import defaultdict | from collections import defaultdict | ||||
| from typing import Any, Optional | from typing import Any, Optional | ||||
| from langchain.schema import BaseRetriever, Document | |||||
| from pydantic import BaseModel, Extra, Field | |||||
| from pydantic import BaseModel | |||||
| from core.index.base import BaseIndex | |||||
| from core.index.keyword_table_index.jieba_keyword_table_handler import JiebaKeywordTableHandler | |||||
| from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler | |||||
| from core.rag.datasource.keyword.keyword_base import BaseKeyword | |||||
| from core.rag.models.document import Document | |||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from models.dataset import Dataset, DatasetKeywordTable, DocumentSegment | from models.dataset import Dataset, DatasetKeywordTable, DocumentSegment | ||||
| max_keywords_per_chunk: int = 10 | max_keywords_per_chunk: int = 10 | ||||
| class KeywordTableIndex(BaseIndex): | |||||
| def __init__(self, dataset: Dataset, config: KeywordTableConfig = KeywordTableConfig()): | |||||
| class Jieba(BaseKeyword): | |||||
| def __init__(self, dataset: Dataset): | |||||
| super().__init__(dataset) | super().__init__(dataset) | ||||
| self._config = config | |||||
| self._config = KeywordTableConfig() | |||||
| def create(self, texts: list[Document], **kwargs) -> BaseIndex: | |||||
| def create(self, texts: list[Document], **kwargs) -> BaseKeyword: | |||||
| keyword_table_handler = JiebaKeywordTableHandler() | keyword_table_handler = JiebaKeywordTableHandler() | ||||
| keyword_table = {} | |||||
| for text in texts: | |||||
| keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk) | |||||
| self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords)) | |||||
| keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords)) | |||||
| dataset_keyword_table = DatasetKeywordTable( | |||||
| dataset_id=self.dataset.id, | |||||
| keyword_table=json.dumps({ | |||||
| '__type__': 'keyword_table', | |||||
| '__data__': { | |||||
| "index_id": self.dataset.id, | |||||
| "summary": None, | |||||
| "table": {} | |||||
| } | |||||
| }, cls=SetEncoder) | |||||
| ) | |||||
| db.session.add(dataset_keyword_table) | |||||
| db.session.commit() | |||||
| self._save_dataset_keyword_table(keyword_table) | |||||
| return self | |||||
| def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex: | |||||
| keyword_table_handler = JiebaKeywordTableHandler() | |||||
| keyword_table = {} | |||||
| keyword_table = self._get_dataset_keyword_table() | |||||
| for text in texts: | for text in texts: | ||||
| keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk) | keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk) | ||||
| self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords)) | self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords)) | ||||
| keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords)) | keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords)) | ||||
| dataset_keyword_table = DatasetKeywordTable( | |||||
| dataset_id=self.dataset.id, | |||||
| keyword_table=json.dumps({ | |||||
| '__type__': 'keyword_table', | |||||
| '__data__': { | |||||
| "index_id": self.dataset.id, | |||||
| "summary": None, | |||||
| "table": {} | |||||
| } | |||||
| }, cls=SetEncoder) | |||||
| ) | |||||
| db.session.add(dataset_keyword_table) | |||||
| db.session.commit() | |||||
| self._save_dataset_keyword_table(keyword_table) | self._save_dataset_keyword_table(keyword_table) | ||||
| return self | return self | ||||
| keyword_table_handler = JiebaKeywordTableHandler() | keyword_table_handler = JiebaKeywordTableHandler() | ||||
| keyword_table = self._get_dataset_keyword_table() | keyword_table = self._get_dataset_keyword_table() | ||||
| for text in texts: | |||||
| keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk) | |||||
| keywords_list = kwargs.get('keywords_list', None) | |||||
| for i in range(len(texts)): | |||||
| text = texts[i] | |||||
| if keywords_list: | |||||
| keywords = keywords_list[i] | |||||
| else: | |||||
| keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk) | |||||
| self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords)) | self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords)) | ||||
| keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords)) | keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords)) | ||||
| self._save_dataset_keyword_table(keyword_table) | self._save_dataset_keyword_table(keyword_table) | ||||
| def delete_by_metadata_field(self, key: str, value: str): | |||||
| pass | |||||
| def get_retriever(self, **kwargs: Any) -> BaseRetriever: | |||||
| return KeywordTableRetriever(index=self, **kwargs) | |||||
| def search( | def search( | ||||
| self, query: str, | self, query: str, | ||||
| **kwargs: Any | **kwargs: Any | ||||
| ) -> list[Document]: | ) -> list[Document]: | ||||
| keyword_table = self._get_dataset_keyword_table() | keyword_table = self._get_dataset_keyword_table() | ||||
| search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {} | |||||
| k = search_kwargs.get('k') if search_kwargs.get('k') else 4 | |||||
| k = kwargs.get('top_k', 4) | |||||
| sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table, query, k) | sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table, query, k) | ||||
| db.session.delete(dataset_keyword_table) | db.session.delete(dataset_keyword_table) | ||||
| db.session.commit() | db.session.commit() | ||||
| def delete_by_group_id(self, group_id: str) -> None: | |||||
| dataset_keyword_table = self.dataset.dataset_keyword_table | |||||
| if dataset_keyword_table: | |||||
| db.session.delete(dataset_keyword_table) | |||||
| db.session.commit() | |||||
| def _save_dataset_keyword_table(self, keyword_table): | def _save_dataset_keyword_table(self, keyword_table): | ||||
| keyword_table_dict = { | keyword_table_dict = { | ||||
| '__type__': 'keyword_table', | '__type__': 'keyword_table', | ||||
| ).first() | ).first() | ||||
| if document_segment: | if document_segment: | ||||
| document_segment.keywords = keywords | document_segment.keywords = keywords | ||||
| db.session.add(document_segment) | |||||
| db.session.commit() | db.session.commit() | ||||
| def create_segment_keywords(self, node_id: str, keywords: list[str]): | def create_segment_keywords(self, node_id: str, keywords: list[str]): | ||||
| self._save_dataset_keyword_table(keyword_table) | self._save_dataset_keyword_table(keyword_table) | ||||
| class KeywordTableRetriever(BaseRetriever, BaseModel): | |||||
| index: KeywordTableIndex | |||||
| search_kwargs: dict = Field(default_factory=dict) | |||||
| class Config: | |||||
| """Configuration for this pydantic object.""" | |||||
| extra = Extra.forbid | |||||
| arbitrary_types_allowed = True | |||||
| def get_relevant_documents(self, query: str) -> list[Document]: | |||||
| """Get documents relevant for a query. | |||||
| Args: | |||||
| query: string to find relevant documents for | |||||
| Returns: | |||||
| List of relevant documents | |||||
| """ | |||||
| return self.index.search(query, **self.search_kwargs) | |||||
| async def aget_relevant_documents(self, query: str) -> list[Document]: | |||||
| raise NotImplementedError("KeywordTableRetriever does not support async") | |||||
| class SetEncoder(json.JSONEncoder): | class SetEncoder(json.JSONEncoder): | ||||
| def default(self, obj): | def default(self, obj): | ||||
| if isinstance(obj, set): | if isinstance(obj, set): |
| import jieba | import jieba | ||||
| from jieba.analyse import default_tfidf | from jieba.analyse import default_tfidf | ||||
| from core.index.keyword_table_index.stopwords import STOPWORDS | |||||
| from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS | |||||
| class JiebaKeywordTableHandler: | class JiebaKeywordTableHandler: |
| from abc import ABC, abstractmethod | from abc import ABC, abstractmethod | ||||
| from typing import Any | from typing import Any | ||||
| from langchain.schema import BaseRetriever, Document | |||||
| from core.rag.models.document import Document | |||||
| from models.dataset import Dataset | from models.dataset import Dataset | ||||
| class BaseIndex(ABC): | |||||
| class BaseKeyword(ABC): | |||||
| def __init__(self, dataset: Dataset): | def __init__(self, dataset: Dataset): | ||||
| self.dataset = dataset | self.dataset = dataset | ||||
| @abstractmethod | @abstractmethod | ||||
| def create(self, texts: list[Document], **kwargs) -> BaseIndex: | |||||
| raise NotImplementedError | |||||
| @abstractmethod | |||||
| def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex: | |||||
| def create(self, texts: list[Document], **kwargs) -> BaseKeyword: | |||||
| raise NotImplementedError | raise NotImplementedError | ||||
| @abstractmethod | @abstractmethod | ||||
| raise NotImplementedError | raise NotImplementedError | ||||
| @abstractmethod | @abstractmethod | ||||
| def delete_by_metadata_field(self, key: str, value: str) -> None: | |||||
| raise NotImplementedError | |||||
| @abstractmethod | |||||
| def delete_by_group_id(self, group_id: str) -> None: | |||||
| raise NotImplementedError | |||||
| @abstractmethod | |||||
| def delete_by_document_id(self, document_id: str): | |||||
| def delete_by_document_id(self, document_id: str) -> None: | |||||
| raise NotImplementedError | raise NotImplementedError | ||||
| @abstractmethod | |||||
| def get_retriever(self, **kwargs: Any) -> BaseRetriever: | |||||
| def delete(self) -> None: | |||||
| raise NotImplementedError | raise NotImplementedError | ||||
| @abstractmethod | |||||
| def search( | def search( | ||||
| self, query: str, | self, query: str, | ||||
| **kwargs: Any | **kwargs: Any | ||||
| ) -> list[Document]: | ) -> list[Document]: | ||||
| raise NotImplementedError | raise NotImplementedError | ||||
| def delete(self) -> None: | |||||
| raise NotImplementedError | |||||
| def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]: | def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]: | ||||
| for text in texts: | for text in texts: | ||||
| doc_id = text.metadata['doc_id'] | doc_id = text.metadata['doc_id'] |
| from typing import Any, cast | |||||
| from flask import current_app | |||||
| from core.rag.datasource.keyword.jieba.jieba import Jieba | |||||
| from core.rag.datasource.keyword.keyword_base import BaseKeyword | |||||
| from core.rag.models.document import Document | |||||
| from models.dataset import Dataset | |||||
| class Keyword: | |||||
| def __init__(self, dataset: Dataset): | |||||
| self._dataset = dataset | |||||
| self._keyword_processor = self._init_keyword() | |||||
| def _init_keyword(self) -> BaseKeyword: | |||||
| config = cast(dict, current_app.config) | |||||
| keyword_type = config.get('KEYWORD_STORE') | |||||
| if not keyword_type: | |||||
| raise ValueError("Keyword store must be specified.") | |||||
| if keyword_type == "jieba": | |||||
| return Jieba( | |||||
| dataset=self._dataset | |||||
| ) | |||||
| else: | |||||
| raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.") | |||||
| def create(self, texts: list[Document], **kwargs): | |||||
| self._keyword_processor.create(texts, **kwargs) | |||||
| def add_texts(self, texts: list[Document], **kwargs): | |||||
| self._keyword_processor.add_texts(texts, **kwargs) | |||||
| def text_exists(self, id: str) -> bool: | |||||
| return self._keyword_processor.text_exists(id) | |||||
| def delete_by_ids(self, ids: list[str]) -> None: | |||||
| self._keyword_processor.delete_by_ids(ids) | |||||
| def delete_by_document_id(self, document_id: str) -> None: | |||||
| self._keyword_processor.delete_by_document_id(document_id) | |||||
| def delete(self) -> None: | |||||
| self._keyword_processor.delete() | |||||
| def search( | |||||
| self, query: str, | |||||
| **kwargs: Any | |||||
| ) -> list[Document]: | |||||
| return self._keyword_processor.search(query, **kwargs) | |||||
| def __getattr__(self, name): | |||||
| if self._keyword_processor is not None: | |||||
| method = getattr(self._keyword_processor, name) | |||||
| if callable(method): | |||||
| return method | |||||
| raise AttributeError(f"'Keyword' object has no attribute '{name}'") |
| import threading | |||||
| from typing import Optional | |||||
| from flask import Flask, current_app | |||||
| from flask_login import current_user | |||||
| from core.rag.data_post_processor.data_post_processor import DataPostProcessor | |||||
| from core.rag.datasource.keyword.keyword_factory import Keyword | |||||
| from core.rag.datasource.vdb.vector_factory import Vector | |||||
| from extensions.ext_database import db | |||||
| from models.dataset import Dataset | |||||
| default_retrieval_model = { | |||||
| 'search_method': 'semantic_search', | |||||
| 'reranking_enable': False, | |||||
| 'reranking_model': { | |||||
| 'reranking_provider_name': '', | |||||
| 'reranking_model_name': '' | |||||
| }, | |||||
| 'top_k': 2, | |||||
| 'score_threshold_enabled': False | |||||
| } | |||||
| class RetrievalService: | |||||
| @classmethod | |||||
| def retrieve(cls, retrival_method: str, dataset_id: str, query: str, | |||||
| top_k: int, score_threshold: Optional[float] = 0.0, reranking_model: Optional[dict] = None): | |||||
| all_documents = [] | |||||
| threads = [] | |||||
| # retrieval_model source with keyword | |||||
| if retrival_method == 'keyword_search': | |||||
| keyword_thread = threading.Thread(target=RetrievalService.keyword_search, kwargs={ | |||||
| 'flask_app': current_app._get_current_object(), | |||||
| 'dataset_id': dataset_id, | |||||
| 'query': query, | |||||
| 'top_k': top_k, | |||||
| 'all_documents': all_documents | |||||
| }) | |||||
| threads.append(keyword_thread) | |||||
| keyword_thread.start() | |||||
| # retrieval_model source with semantic | |||||
| if retrival_method == 'semantic_search' or retrival_method == 'hybrid_search': | |||||
| embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={ | |||||
| 'flask_app': current_app._get_current_object(), | |||||
| 'dataset_id': dataset_id, | |||||
| 'query': query, | |||||
| 'top_k': top_k, | |||||
| 'score_threshold': score_threshold, | |||||
| 'reranking_model': reranking_model, | |||||
| 'all_documents': all_documents, | |||||
| 'retrival_method': retrival_method | |||||
| }) | |||||
| threads.append(embedding_thread) | |||||
| embedding_thread.start() | |||||
| # retrieval source with full text | |||||
| if retrival_method == 'full_text_search' or retrival_method == 'hybrid_search': | |||||
| full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={ | |||||
| 'flask_app': current_app._get_current_object(), | |||||
| 'dataset_id': dataset_id, | |||||
| 'query': query, | |||||
| 'retrival_method': retrival_method, | |||||
| 'score_threshold': score_threshold, | |||||
| 'top_k': top_k, | |||||
| 'reranking_model': reranking_model, | |||||
| 'all_documents': all_documents | |||||
| }) | |||||
| threads.append(full_text_index_thread) | |||||
| full_text_index_thread.start() | |||||
| for thread in threads: | |||||
| thread.join() | |||||
| if retrival_method == 'hybrid_search': | |||||
| data_post_processor = DataPostProcessor(str(current_user.current_tenant_id), reranking_model, False) | |||||
| all_documents = data_post_processor.invoke( | |||||
| query=query, | |||||
| documents=all_documents, | |||||
| score_threshold=score_threshold, | |||||
| top_n=top_k | |||||
| ) | |||||
| return all_documents | |||||
| @classmethod | |||||
| def keyword_search(cls, flask_app: Flask, dataset_id: str, query: str, | |||||
| top_k: int, all_documents: list): | |||||
| with flask_app.app_context(): | |||||
| dataset = db.session.query(Dataset).filter( | |||||
| Dataset.id == dataset_id | |||||
| ).first() | |||||
| keyword = Keyword( | |||||
| dataset=dataset | |||||
| ) | |||||
| documents = keyword.search( | |||||
| query, | |||||
| top_k=top_k | |||||
| ) | |||||
| all_documents.extend(documents) | |||||
| @classmethod | |||||
| def embedding_search(cls, flask_app: Flask, dataset_id: str, query: str, | |||||
| top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict], | |||||
| all_documents: list, retrival_method: str): | |||||
| with flask_app.app_context(): | |||||
| dataset = db.session.query(Dataset).filter( | |||||
| Dataset.id == dataset_id | |||||
| ).first() | |||||
| vector = Vector( | |||||
| dataset=dataset | |||||
| ) | |||||
| documents = vector.search_by_vector( | |||||
| query, | |||||
| search_type='similarity_score_threshold', | |||||
| top_k=top_k, | |||||
| score_threshold=score_threshold, | |||||
| filter={ | |||||
| 'group_id': [dataset.id] | |||||
| } | |||||
| ) | |||||
| if documents: | |||||
| if reranking_model and retrival_method == 'semantic_search': | |||||
| data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False) | |||||
| all_documents.extend(data_post_processor.invoke( | |||||
| query=query, | |||||
| documents=documents, | |||||
| score_threshold=score_threshold, | |||||
| top_n=len(documents) | |||||
| )) | |||||
| else: | |||||
| all_documents.extend(documents) | |||||
| @classmethod | |||||
| def full_text_index_search(cls, flask_app: Flask, dataset_id: str, query: str, | |||||
| top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict], | |||||
| all_documents: list, retrival_method: str): | |||||
| with flask_app.app_context(): | |||||
| dataset = db.session.query(Dataset).filter( | |||||
| Dataset.id == dataset_id | |||||
| ).first() | |||||
| vector_processor = Vector( | |||||
| dataset=dataset, | |||||
| ) | |||||
| documents = vector_processor.search_by_full_text( | |||||
| query, | |||||
| top_k=top_k | |||||
| ) | |||||
| if documents: | |||||
| if reranking_model and retrival_method == 'full_text_search': | |||||
| data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False) | |||||
| all_documents.extend(data_post_processor.invoke( | |||||
| query=query, | |||||
| documents=documents, | |||||
| score_threshold=score_threshold, | |||||
| top_n=len(documents) | |||||
| )) | |||||
| else: | |||||
| all_documents.extend(documents) |
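Taken together, retrieve() fans the query out to keyword, vector and full-text backends on separate threads, merges the results, and reranks once more for hybrid search. A minimal call-site sketch (not part of the change set; the dataset id and reranking settings are placeholders, the reranking dict is assumed to follow the shape of default_retrieval_model above, and an active Flask request context is assumed for current_app/current_user):

    # hypothetical caller, e.g. inside a request handler
    documents = RetrievalService.retrieve(
        retrival_method='hybrid_search',
        dataset_id='your-dataset-id',
        query='How do I rotate API keys?',
        top_k=4,
        score_threshold=0.5,
        reranking_model={
            'reranking_provider_name': 'cohere',                # placeholder provider
            'reranking_model_name': 'rerank-multilingual-v2.0'  # placeholder model
        }
    )
    for doc in documents:
        print(doc.metadata.get('score'), doc.page_content[:80])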
| from enum import Enum | |||||
| class Field(Enum): | |||||
| CONTENT_KEY = "page_content" | |||||
| METADATA_KEY = "metadata" | |||||
| GROUP_KEY = "group_id" | |||||
| VECTOR = "vector" | |||||
| TEXT_KEY = "text" | |||||
| PRIMARY_KEY = " id" |
| import logging | |||||
| from typing import Any, Optional | |||||
| from uuid import uuid4 | |||||
| from pydantic import BaseModel, root_validator | |||||
| from pymilvus import MilvusClient, MilvusException, connections | |||||
| from core.rag.datasource.vdb.field import Field | |||||
| from core.rag.datasource.vdb.vector_base import BaseVector | |||||
| from core.rag.models.document import Document | |||||
| logger = logging.getLogger(__name__) | |||||
| class MilvusConfig(BaseModel): | |||||
| host: str | |||||
| port: int | |||||
| user: str | |||||
| password: str | |||||
| secure: bool = False | |||||
| batch_size: int = 100 | |||||
| @root_validator() | |||||
| def validate_config(cls, values: dict) -> dict: | |||||
| if not values['host']: | |||||
| raise ValueError("config MILVUS_HOST is required") | |||||
| if not values['port']: | |||||
| raise ValueError("config MILVUS_PORT is required") | |||||
| if not values['user']: | |||||
| raise ValueError("config MILVUS_USER is required") | |||||
| if not values['password']: | |||||
| raise ValueError("config MILVUS_PASSWORD is required") | |||||
| return values | |||||
| def to_milvus_params(self): | |||||
| return { | |||||
| 'host': self.host, | |||||
| 'port': self.port, | |||||
| 'user': self.user, | |||||
| 'password': self.password, | |||||
| 'secure': self.secure | |||||
| } | |||||
| class MilvusVector(BaseVector): | |||||
| def __init__(self, collection_name: str, config: MilvusConfig): | |||||
| super().__init__(collection_name) | |||||
| self._client_config = config | |||||
| self._client = self._init_client(config) | |||||
| self._consistency_level = 'Session' | |||||
| self._fields = [] | |||||
| def get_type(self) -> str: | |||||
| return 'milvus' | |||||
| def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): | |||||
| index_params = { | |||||
| 'metric_type': 'IP', | |||||
| 'index_type': "HNSW", | |||||
| 'params': {"M": 8, "efConstruction": 64} | |||||
| } | |||||
| metadatas = [d.metadata for d in texts] | |||||
| # Grab the existing collection if it exists | |||||
| from pymilvus import utility | |||||
| alias = uuid4().hex | |||||
| if self._client_config.secure: | |||||
| uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port) | |||||
| else: | |||||
| uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port) | |||||
| connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password) | |||||
| if not utility.has_collection(self._collection_name, using=alias): | |||||
| self.create_collection(embeddings, metadatas, index_params) | |||||
| self.add_texts(texts, embeddings) | |||||
| def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): | |||||
| insert_dict_list = [] | |||||
| for i in range(len(documents)): | |||||
| insert_dict = { | |||||
| Field.CONTENT_KEY.value: documents[i].page_content, | |||||
| Field.VECTOR.value: embeddings[i], | |||||
| Field.METADATA_KEY.value: documents[i].metadata | |||||
| } | |||||
| insert_dict_list.append(insert_dict) | |||||
| # Total insert count | |||||
| total_count = len(insert_dict_list) | |||||
| pks: list[str] = [] | |||||
| for i in range(0, total_count, 1000): | |||||
| batch_insert_list = insert_dict_list[i:i + 1000] | |||||
| # Insert into the collection. | |||||
| try: | |||||
| ids = self._client.insert(collection_name=self._collection_name, data=batch_insert_list) | |||||
| pks.extend(ids) | |||||
| except MilvusException as e: | |||||
| logger.error( | |||||
| "Failed to insert batch starting at entity: %s/%s", i, total_count | |||||
| ) | |||||
| raise e | |||||
| return pks | |||||
| def delete_by_document_id(self, document_id: str): | |||||
| ids = self.get_ids_by_metadata_field('document_id', document_id) | |||||
| if ids: | |||||
| self._client.delete(collection_name=self._collection_name, pks=ids) | |||||
| def get_ids_by_metadata_field(self, key: str, value: str): | |||||
| result = self._client.query(collection_name=self._collection_name, | |||||
| filter=f'metadata["{key}"] == "{value}"', | |||||
| output_fields=["id"]) | |||||
| if result: | |||||
| return [item["id"] for item in result] | |||||
| else: | |||||
| return None | |||||
| def delete_by_metadata_field(self, key: str, value: str): | |||||
| ids = self.get_ids_by_metadata_field(key, value) | |||||
| if ids: | |||||
| self._client.delete(collection_name=self._collection_name, pks=ids) | |||||
| def delete_by_ids(self, doc_ids: list[str]) -> None: | |||||
| self._client.delete(collection_name=self._collection_name, pks=doc_ids) | |||||
| def delete(self) -> None: | |||||
| from pymilvus import utility | |||||
| utility.drop_collection(self._collection_name, None) | |||||
| def text_exists(self, id: str) -> bool: | |||||
| result = self._client.query(collection_name=self._collection_name, | |||||
| filter=f'metadata["doc_id"] == "{id}"', | |||||
| output_fields=["id"]) | |||||
| return len(result) > 0 | |||||
| def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: | |||||
| # Set search parameters. | |||||
| results = self._client.search(collection_name=self._collection_name, | |||||
| data=[query_vector], | |||||
| limit=kwargs.get('top_k', 4), | |||||
| output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value], | |||||
| ) | |||||
| # Organize results. | |||||
| docs = [] | |||||
| for result in results[0]: | |||||
| metadata = result['entity'].get(Field.METADATA_KEY.value) | |||||
| metadata['score'] = result['distance'] | |||||
| score_threshold = kwargs.get('score_threshold') or 0.0 | |||||
| if result['distance'] > score_threshold: | |||||
| doc = Document(page_content=result['entity'].get(Field.CONTENT_KEY.value), | |||||
| metadata=metadata) | |||||
| docs.append(doc) | |||||
| return docs | |||||
| def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: | |||||
| # milvus/zilliz doesn't support bm25 search | |||||
| return [] | |||||
| def create_collection( | |||||
| self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None | |||||
| ) -> str: | |||||
| from pymilvus import CollectionSchema, DataType, FieldSchema | |||||
| from pymilvus.orm.types import infer_dtype_bydata | |||||
| # Determine embedding dim | |||||
| dim = len(embeddings[0]) | |||||
| fields = [] | |||||
| if metadatas: | |||||
| fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535)) | |||||
| # Create the text field | |||||
| fields.append( | |||||
| FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535) | |||||
| ) | |||||
| # Create the primary key field | |||||
| fields.append( | |||||
| FieldSchema( | |||||
| Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True | |||||
| ) | |||||
| ) | |||||
| # Create the vector field, supports binary or float vectors | |||||
| fields.append( | |||||
| FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim) | |||||
| ) | |||||
| # Create the schema for the collection | |||||
| schema = CollectionSchema(fields) | |||||
| for x in schema.fields: | |||||
| self._fields.append(x.name) | |||||
| # Since primary field is auto-id, no need to track it | |||||
| self._fields.remove(Field.PRIMARY_KEY.value) | |||||
| # Create the collection | |||||
| collection_name = self._collection_name | |||||
| self._client.create_collection_with_schema(collection_name=collection_name, | |||||
| schema=schema, index_param=index_params, | |||||
| consistency_level=self._consistency_level) | |||||
| return collection_name | |||||
| def _init_client(self, config) -> MilvusClient: | |||||
| if config.secure: | |||||
| uri = "https://" + str(config.host) + ":" + str(config.port) | |||||
| else: | |||||
| uri = "http://" + str(config.host) + ":" + str(config.port) | |||||
| client = MilvusClient(uri=uri, user=config.user, password=config.password) | |||||
| return client |
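A rough usage sketch of this backend (not part of the change set; host, credentials and the four-dimensional embeddings are placeholders, and a reachable Milvus server is assumed):

    from core.rag.models.document import Document

    config = MilvusConfig(host='127.0.0.1', port=19530, user='root', password='Milvus')  # placeholder credentials
    vector = MilvusVector(collection_name='Vector_index_example_Node', config=config)

    docs = [Document(page_content='hello milvus', metadata={'doc_id': 'doc-1', 'document_id': 'document-1'})]
    embeddings = [[0.1, 0.2, 0.3, 0.4]]  # normally produced by the embedding model
    vector.create(texts=docs, embeddings=embeddings)
    hits = vector.search_by_vector([0.1, 0.2, 0.3, 0.4], top_k=2, score_threshold=0.0)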
| import os | |||||
| import uuid | |||||
| from collections.abc import Generator, Iterable, Sequence | |||||
| from itertools import islice | |||||
| from typing import TYPE_CHECKING, Any, Optional, Union, cast | |||||
| import qdrant_client | |||||
| from pydantic import BaseModel | |||||
| from qdrant_client.http import models as rest | |||||
| from qdrant_client.http.models import ( | |||||
| FilterSelector, | |||||
| HnswConfigDiff, | |||||
| PayloadSchemaType, | |||||
| TextIndexParams, | |||||
| TextIndexType, | |||||
| TokenizerType, | |||||
| ) | |||||
| from qdrant_client.local.qdrant_local import QdrantLocal | |||||
| from core.rag.datasource.vdb.field import Field | |||||
| from core.rag.datasource.vdb.vector_base import BaseVector | |||||
| from core.rag.models.document import Document | |||||
| if TYPE_CHECKING: | |||||
| from qdrant_client import grpc # noqa | |||||
| from qdrant_client.conversions import common_types | |||||
| from qdrant_client.http import models as rest | |||||
| DictFilter = dict[str, Union[str, int, bool, dict, list]] | |||||
| MetadataFilter = Union[DictFilter, common_types.Filter] | |||||
| class QdrantConfig(BaseModel): | |||||
| endpoint: str | |||||
| api_key: Optional[str] | |||||
| timeout: float = 20 | |||||
| root_path: Optional[str] | |||||
| def to_qdrant_params(self): | |||||
| if self.endpoint and self.endpoint.startswith('path:'): | |||||
| path = self.endpoint.replace('path:', '') | |||||
| if not os.path.isabs(path): | |||||
| path = os.path.join(self.root_path, path) | |||||
| return { | |||||
| 'path': path | |||||
| } | |||||
| else: | |||||
| return { | |||||
| 'url': self.endpoint, | |||||
| 'api_key': self.api_key, | |||||
| 'timeout': self.timeout | |||||
| } | |||||
| class QdrantVector(BaseVector): | |||||
| def __init__(self, collection_name: str, group_id: str, config: QdrantConfig, distance_func: str = 'Cosine'): | |||||
| super().__init__(collection_name) | |||||
| self._client_config = config | |||||
| self._client = qdrant_client.QdrantClient(**self._client_config.to_qdrant_params()) | |||||
| self._distance_func = distance_func.upper() | |||||
| self._group_id = group_id | |||||
| def get_type(self) -> str: | |||||
| return 'qdrant' | |||||
| def to_index_struct(self) -> dict: | |||||
| return { | |||||
| "type": self.get_type(), | |||||
| "vector_store": {"class_prefix": self._collection_name} | |||||
| } | |||||
| def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): | |||||
| if texts: | |||||
| # get embedding vector size | |||||
| vector_size = len(embeddings[0]) | |||||
| # get collection name | |||||
| collection_name = self._collection_name | |||||
| collection_name = collection_name or uuid.uuid4().hex | |||||
| all_collection_name = [] | |||||
| collections_response = self._client.get_collections() | |||||
| collection_list = collections_response.collections | |||||
| for collection in collection_list: | |||||
| all_collection_name.append(collection.name) | |||||
| if collection_name not in all_collection_name: | |||||
| # create collection | |||||
| self.create_collection(collection_name, vector_size) | |||||
| self.add_texts(texts, embeddings, **kwargs) | |||||
| def create_collection(self, collection_name: str, vector_size: int): | |||||
| from qdrant_client.http import models as rest | |||||
| vectors_config = rest.VectorParams( | |||||
| size=vector_size, | |||||
| distance=rest.Distance[self._distance_func], | |||||
| ) | |||||
| hnsw_config = HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000, | |||||
| max_indexing_threads=0, on_disk=False) | |||||
| self._client.recreate_collection( | |||||
| collection_name=collection_name, | |||||
| vectors_config=vectors_config, | |||||
| hnsw_config=hnsw_config, | |||||
| timeout=int(self._client_config.timeout), | |||||
| ) | |||||
| # create payload index | |||||
| self._client.create_payload_index(collection_name, Field.GROUP_KEY.value, | |||||
| field_schema=PayloadSchemaType.KEYWORD, | |||||
| field_type=PayloadSchemaType.KEYWORD) | |||||
| # create full text index | |||||
| text_index_params = TextIndexParams( | |||||
| type=TextIndexType.TEXT, | |||||
| tokenizer=TokenizerType.MULTILINGUAL, | |||||
| min_token_len=2, | |||||
| max_token_len=20, | |||||
| lowercase=True | |||||
| ) | |||||
| self._client.create_payload_index(collection_name, Field.CONTENT_KEY.value, | |||||
| field_schema=text_index_params) | |||||
| def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): | |||||
| uuids = self._get_uuids(documents) | |||||
| texts = [d.page_content for d in documents] | |||||
| metadatas = [d.metadata for d in documents] | |||||
| added_ids = [] | |||||
| for batch_ids, points in self._generate_rest_batches( | |||||
| texts, embeddings, metadatas, uuids, 64, self._group_id | |||||
| ): | |||||
| self._client.upsert( | |||||
| collection_name=self._collection_name, points=points | |||||
| ) | |||||
| added_ids.extend(batch_ids) | |||||
| return added_ids | |||||
| def _generate_rest_batches( | |||||
| self, | |||||
| texts: Iterable[str], | |||||
| embeddings: list[list[float]], | |||||
| metadatas: Optional[list[dict]] = None, | |||||
| ids: Optional[Sequence[str]] = None, | |||||
| batch_size: int = 64, | |||||
| group_id: Optional[str] = None, | |||||
| ) -> Generator[tuple[list[str], list[rest.PointStruct]], None, None]: | |||||
| from qdrant_client.http import models as rest | |||||
| texts_iterator = iter(texts) | |||||
| embeddings_iterator = iter(embeddings) | |||||
| metadatas_iterator = iter(metadatas or []) | |||||
| ids_iterator = iter(ids or [uuid.uuid4().hex for _ in iter(texts)]) | |||||
| while batch_texts := list(islice(texts_iterator, batch_size)): | |||||
| # Take the corresponding metadata and id for each text in a batch | |||||
| batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None | |||||
| batch_ids = list(islice(ids_iterator, batch_size)) | |||||
| # Generate the embeddings for all the texts in a batch | |||||
| batch_embeddings = list(islice(embeddings_iterator, batch_size)) | |||||
| points = [ | |||||
| rest.PointStruct( | |||||
| id=point_id, | |||||
| vector=vector, | |||||
| payload=payload, | |||||
| ) | |||||
| for point_id, vector, payload in zip( | |||||
| batch_ids, | |||||
| batch_embeddings, | |||||
| self._build_payloads( | |||||
| batch_texts, | |||||
| batch_metadatas, | |||||
| Field.CONTENT_KEY.value, | |||||
| Field.METADATA_KEY.value, | |||||
| group_id, | |||||
| Field.GROUP_KEY.value, | |||||
| ), | |||||
| ) | |||||
| ] | |||||
| yield batch_ids, points | |||||
| @classmethod | |||||
| def _build_payloads( | |||||
| cls, | |||||
| texts: Iterable[str], | |||||
| metadatas: Optional[list[dict]], | |||||
| content_payload_key: str, | |||||
| metadata_payload_key: str, | |||||
| group_id: str, | |||||
| group_payload_key: str | |||||
| ) -> list[dict]: | |||||
| payloads = [] | |||||
| for i, text in enumerate(texts): | |||||
| if text is None: | |||||
| raise ValueError( | |||||
| "At least one of the texts is None. Please remove it before " | |||||
| "calling .from_texts or .add_texts on Qdrant instance." | |||||
| ) | |||||
| metadata = metadatas[i] if metadatas is not None else None | |||||
| payloads.append( | |||||
| { | |||||
| content_payload_key: text, | |||||
| metadata_payload_key: metadata, | |||||
| group_payload_key: group_id | |||||
| } | |||||
| ) | |||||
| return payloads | |||||
| def delete_by_metadata_field(self, key: str, value: str): | |||||
| from qdrant_client.http import models | |||||
| filter = models.Filter( | |||||
| must=[ | |||||
| models.FieldCondition( | |||||
| key=f"metadata.{key}", | |||||
| match=models.MatchValue(value=value), | |||||
| ), | |||||
| ], | |||||
| ) | |||||
| self._reload_if_needed() | |||||
| self._client.delete( | |||||
| collection_name=self._collection_name, | |||||
| points_selector=FilterSelector( | |||||
| filter=filter | |||||
| ), | |||||
| ) | |||||
| def delete(self): | |||||
| from qdrant_client.http import models | |||||
| filter = models.Filter( | |||||
| must=[ | |||||
| models.FieldCondition( | |||||
| key="group_id", | |||||
| match=models.MatchValue(value=self._group_id), | |||||
| ), | |||||
| ], | |||||
| ) | |||||
| self._client.delete( | |||||
| collection_name=self._collection_name, | |||||
| points_selector=FilterSelector( | |||||
| filter=filter | |||||
| ), | |||||
| ) | |||||
| def delete_by_ids(self, ids: list[str]) -> None: | |||||
| from qdrant_client.http import models | |||||
| for node_id in ids: | |||||
| filter = models.Filter( | |||||
| must=[ | |||||
| models.FieldCondition( | |||||
| key="metadata.doc_id", | |||||
| match=models.MatchValue(value=node_id), | |||||
| ), | |||||
| ], | |||||
| ) | |||||
| self._client.delete( | |||||
| collection_name=self._collection_name, | |||||
| points_selector=FilterSelector( | |||||
| filter=filter | |||||
| ), | |||||
| ) | |||||
| def text_exists(self, id: str) -> bool: | |||||
| response = self._client.retrieve( | |||||
| collection_name=self._collection_name, | |||||
| ids=[id] | |||||
| ) | |||||
| return len(response) > 0 | |||||
| def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: | |||||
| from qdrant_client.http import models | |||||
| filter = models.Filter( | |||||
| must=[ | |||||
| models.FieldCondition( | |||||
| key="group_id", | |||||
| match=models.MatchValue(value=self._group_id), | |||||
| ), | |||||
| ], | |||||
| ) | |||||
| results = self._client.search( | |||||
| collection_name=self._collection_name, | |||||
| query_vector=query_vector, | |||||
| query_filter=filter, | |||||
| limit=kwargs.get("top_k", 4), | |||||
| with_payload=True, | |||||
| with_vectors=True, | |||||
| score_threshold=kwargs.get("score_threshold", .0) | |||||
| ) | |||||
| docs = [] | |||||
| for result in results: | |||||
| metadata = result.payload.get(Field.METADATA_KEY.value) or {} | |||||
| # check score threshold | |||||
| score_threshold = kwargs.get('score_threshold') or 0.0 | |||||
| if result.score > score_threshold: | |||||
| metadata['score'] = result.score | |||||
| doc = Document( | |||||
| page_content=result.payload.get(Field.CONTENT_KEY.value), | |||||
| metadata=metadata, | |||||
| ) | |||||
| docs.append(doc) | |||||
| return docs | |||||
| def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: | |||||
| """Return docs most similar by bm25. | |||||
| Returns: | |||||
| List of documents most similar to the query text and distance for each. | |||||
| """ | |||||
| from qdrant_client.http import models | |||||
| scroll_filter = models.Filter( | |||||
| must=[ | |||||
| models.FieldCondition( | |||||
| key="group_id", | |||||
| match=models.MatchValue(value=self._group_id), | |||||
| ), | |||||
| models.FieldCondition( | |||||
| key="page_content", | |||||
| match=models.MatchText(text=query), | |||||
| ) | |||||
| ] | |||||
| ) | |||||
| response = self._client.scroll( | |||||
| collection_name=self._collection_name, | |||||
| scroll_filter=scroll_filter, | |||||
| limit=kwargs.get('top_k', 2), | |||||
| with_payload=True, | |||||
| with_vectors=True | |||||
| ) | |||||
| results = response[0] | |||||
| documents = [] | |||||
| for result in results: | |||||
| if result: | |||||
| documents.append(self._document_from_scored_point( | |||||
| result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value | |||||
| )) | |||||
| return documents | |||||
| def _reload_if_needed(self): | |||||
| if isinstance(self._client, QdrantLocal): | |||||
| self._client = cast(QdrantLocal, self._client) | |||||
| self._client._load() | |||||
| @classmethod | |||||
| def _document_from_scored_point( | |||||
| cls, | |||||
| scored_point: Any, | |||||
| content_payload_key: str, | |||||
| metadata_payload_key: str, | |||||
| ) -> Document: | |||||
| return Document( | |||||
| page_content=scored_point.payload.get(content_payload_key), | |||||
| metadata=scored_point.payload.get(metadata_payload_key) or {}, | |||||
| ) |
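A rough usage sketch (not part of the change set; the endpoint and embeddings are placeholders, a running Qdrant server is assumed, and doc_id values must be UUIDs because they are reused as point ids):

    import uuid
    from core.rag.models.document import Document

    config = QdrantConfig(endpoint='http://localhost:6333', api_key=None, root_path=None)  # 'path:...' endpoints use local storage instead
    vector = QdrantVector(collection_name='Vector_index_example_Node', group_id='your-dataset-id', config=config)

    docs = [Document(page_content='hello qdrant', metadata={'doc_id': str(uuid.uuid4())})]
    vector.create(texts=docs, embeddings=[[0.1, 0.2, 0.3]])
    semantic_hits = vector.search_by_vector([0.1, 0.2, 0.3], top_k=2, score_threshold=0.0)
    text_hits = vector.search_by_full_text('hello', top_k=2)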
| from __future__ import annotations | |||||
| from abc import ABC, abstractmethod | |||||
| from typing import Any | |||||
| from core.rag.models.document import Document | |||||
| class BaseVector(ABC): | |||||
| def __init__(self, collection_name: str): | |||||
| self._collection_name = collection_name | |||||
| @abstractmethod | |||||
| def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): | |||||
| raise NotImplementedError | |||||
| @abstractmethod | |||||
| def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): | |||||
| raise NotImplementedError | |||||
| @abstractmethod | |||||
| def text_exists(self, id: str) -> bool: | |||||
| raise NotImplementedError | |||||
| @abstractmethod | |||||
| def delete_by_ids(self, ids: list[str]) -> None: | |||||
| raise NotImplementedError | |||||
| @abstractmethod | |||||
| def delete_by_metadata_field(self, key: str, value: str) -> None: | |||||
| raise NotImplementedError | |||||
| @abstractmethod | |||||
| def search_by_vector( | |||||
| self, | |||||
| query_vector: list[float], | |||||
| **kwargs: Any | |||||
| ) -> list[Document]: | |||||
| raise NotImplementedError | |||||
| @abstractmethod | |||||
| def search_by_full_text( | |||||
| self, query: str, | |||||
| **kwargs: Any | |||||
| ) -> list[Document]: | |||||
| raise NotImplementedError | |||||
| def delete(self) -> None: | |||||
| raise NotImplementedError | |||||
| def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]: | |||||
| # iterate over a copy so removals don't skip elements | |||||
| for text in texts[:]: | |||||
| doc_id = text.metadata['doc_id'] | |||||
| exists_duplicate_node = self.text_exists(doc_id) | |||||
| if exists_duplicate_node: | |||||
| texts.remove(text) | |||||
| return texts | |||||
| def _get_uuids(self, texts: list[Document]) -> list[str]: | |||||
| return [text.metadata['doc_id'] for text in texts] |
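For orientation, a purely hypothetical in-memory backend showing what this abstract interface requires of a new vector store (dot-product ranking; illustrative only, not part of the change set):

    from typing import Any

    from core.rag.datasource.vdb.vector_base import BaseVector
    from core.rag.models.document import Document


    class InMemoryVector(BaseVector):
        """Illustrative only: stores (Document, embedding) pairs in a dict keyed by doc_id."""

        def __init__(self, collection_name: str):
            super().__init__(collection_name)
            self._store = {}  # doc_id -> (Document, embedding)

        def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
            self.add_texts(texts, embeddings, **kwargs)

        def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
            for doc, emb in zip(documents, embeddings):
                self._store[doc.metadata['doc_id']] = (doc, emb)

        def text_exists(self, id: str) -> bool:
            return id in self._store

        def delete_by_ids(self, ids: list[str]) -> None:
            for doc_id in ids:
                self._store.pop(doc_id, None)

        def delete_by_metadata_field(self, key: str, value: str) -> None:
            self._store = {k: v for k, v in self._store.items() if v[0].metadata.get(key) != value}

        def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
            # rank by dot product, highest first
            ranked = sorted(self._store.values(),
                            key=lambda item: sum(a * b for a, b in zip(query_vector, item[1])),
                            reverse=True)
            return [doc for doc, _ in ranked[:kwargs.get('top_k', 4)]]

        def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
            matches = [doc for doc, _ in self._store.values() if query in doc.page_content]
            return matches[:kwargs.get('top_k', 2)]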
| from typing import Any, cast | |||||
| from flask import current_app | |||||
| from core.embedding.cached_embedding import CacheEmbedding | |||||
| from core.model_manager import ModelManager | |||||
| from core.model_runtime.entities.model_entities import ModelType | |||||
| from core.rag.datasource.entity.embedding import Embeddings | |||||
| from core.rag.datasource.vdb.vector_base import BaseVector | |||||
| from core.rag.models.document import Document | |||||
| from extensions.ext_database import db | |||||
| from models.dataset import Dataset, DatasetCollectionBinding | |||||
| class Vector: | |||||
| def __init__(self, dataset: Dataset, attributes: list = None): | |||||
| if attributes is None: | |||||
| attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash'] | |||||
| self._dataset = dataset | |||||
| self._embeddings = self._get_embeddings() | |||||
| self._attributes = attributes | |||||
| self._vector_processor = self._init_vector() | |||||
| def _init_vector(self) -> BaseVector: | |||||
| config = cast(dict, current_app.config) | |||||
| vector_type = config.get('VECTOR_STORE') | |||||
| if self._dataset.index_struct_dict: | |||||
| vector_type = self._dataset.index_struct_dict['type'] | |||||
| if not vector_type: | |||||
| raise ValueError("Vector store must be specified.") | |||||
| if vector_type == "weaviate": | |||||
| from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateConfig, WeaviateVector | |||||
| if self._dataset.index_struct_dict: | |||||
| class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix'] | |||||
| collection_name = class_prefix | |||||
| else: | |||||
| dataset_id = self._dataset.id | |||||
| collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node' | |||||
| return WeaviateVector( | |||||
| collection_name=collection_name, | |||||
| config=WeaviateConfig( | |||||
| endpoint=config.get('WEAVIATE_ENDPOINT'), | |||||
| api_key=config.get('WEAVIATE_API_KEY'), | |||||
| batch_size=int(config.get('WEAVIATE_BATCH_SIZE')) | |||||
| ), | |||||
| attributes=self._attributes | |||||
| ) | |||||
| elif vector_type == "qdrant": | |||||
| from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig, QdrantVector | |||||
| if self._dataset.collection_binding_id: | |||||
| dataset_collection_binding = db.session.query(DatasetCollectionBinding). \ | |||||
| filter(DatasetCollectionBinding.id == self._dataset.collection_binding_id). \ | |||||
| one_or_none() | |||||
| if dataset_collection_binding: | |||||
| collection_name = dataset_collection_binding.collection_name | |||||
| else: | |||||
| raise ValueError('Dataset Collection Binding does not exist.') | |||||
| else: | |||||
| if self._dataset.index_struct_dict: | |||||
| class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix'] | |||||
| collection_name = class_prefix | |||||
| else: | |||||
| dataset_id = self._dataset.id | |||||
| collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node' | |||||
| return QdrantVector( | |||||
| collection_name=collection_name, | |||||
| group_id=self._dataset.id, | |||||
| config=QdrantConfig( | |||||
| endpoint=config.get('QDRANT_URL'), | |||||
| api_key=config.get('QDRANT_API_KEY'), | |||||
| root_path=current_app.root_path, | |||||
| timeout=config.get('QDRANT_CLIENT_TIMEOUT') | |||||
| ) | |||||
| ) | |||||
| elif vector_type == "milvus": | |||||
| from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig, MilvusVector | |||||
| if self._dataset.index_struct_dict: | |||||
| class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix'] | |||||
| collection_name = class_prefix | |||||
| else: | |||||
| dataset_id = self._dataset.id | |||||
| collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node' | |||||
| return MilvusVector( | |||||
| collection_name=collection_name, | |||||
| config=MilvusConfig( | |||||
| host=config.get('MILVUS_HOST'), | |||||
| port=config.get('MILVUS_PORT'), | |||||
| user=config.get('MILVUS_USER'), | |||||
| password=config.get('MILVUS_PASSWORD'), | |||||
| secure=config.get('MILVUS_SECURE'), | |||||
| ) | |||||
| ) | |||||
| else: | |||||
| raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.") | |||||
| def create(self, texts: list = None, **kwargs): | |||||
| if texts: | |||||
| embeddings = self._embeddings.embed_documents([document.page_content for document in texts]) | |||||
| self._vector_processor.create( | |||||
| texts=texts, | |||||
| embeddings=embeddings, | |||||
| **kwargs | |||||
| ) | |||||
| def add_texts(self, documents: list[Document], **kwargs): | |||||
| if kwargs.get('duplicate_check', False): | |||||
| documents = self._filter_duplicate_texts(documents) | |||||
| embeddings = self._embeddings.embed_documents([document.page_content for document in documents]) | |||||
| self._vector_processor.add_texts( | |||||
| documents=documents, | |||||
| embeddings=embeddings, | |||||
| **kwargs | |||||
| ) | |||||
| def text_exists(self, id: str) -> bool: | |||||
| return self._vector_processor.text_exists(id) | |||||
| def delete_by_ids(self, ids: list[str]) -> None: | |||||
| self._vector_processor.delete_by_ids(ids) | |||||
| def delete_by_metadata_field(self, key: str, value: str) -> None: | |||||
| self._vector_processor.delete_by_metadata_field(key, value) | |||||
| def search_by_vector( | |||||
| self, query: str, | |||||
| **kwargs: Any | |||||
| ) -> list[Document]: | |||||
| query_vector = self._embeddings.embed_query(query) | |||||
| return self._vector_processor.search_by_vector(query_vector, **kwargs) | |||||
| def search_by_full_text( | |||||
| self, query: str, | |||||
| **kwargs: Any | |||||
| ) -> list[Document]: | |||||
| return self._vector_processor.search_by_full_text(query, **kwargs) | |||||
| def delete(self) -> None: | |||||
| self._vector_processor.delete() | |||||
| def _get_embeddings(self) -> Embeddings: | |||||
| model_manager = ModelManager() | |||||
| embedding_model = model_manager.get_model_instance( | |||||
| tenant_id=self._dataset.tenant_id, | |||||
| provider=self._dataset.embedding_model_provider, | |||||
| model_type=ModelType.TEXT_EMBEDDING, | |||||
| model=self._dataset.embedding_model | |||||
| ) | |||||
| return CacheEmbedding(embedding_model) | |||||
| def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]: | |||||
| # iterate over a copy so removals don't skip elements | |||||
| for text in texts[:]: | |||||
| doc_id = text.metadata['doc_id'] | |||||
| exists_duplicate_node = self.text_exists(doc_id) | |||||
| if exists_duplicate_node: | |||||
| texts.remove(text) | |||||
| return texts | |||||
| def __getattr__(self, name): | |||||
| if self._vector_processor is not None: | |||||
| method = getattr(self._vector_processor, name) | |||||
| if callable(method): | |||||
| return method | |||||
| raise AttributeError(f"'vector_processor' object has no attribute '{name}'") |
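A rough end-to-end sketch of the factory (not part of the change set; the dataset id and metadata values are placeholders, and it assumes an active Flask app context plus a dataset whose embedding model is configured):

    from core.rag.models.document import Document
    from extensions.ext_database import db
    from models.dataset import Dataset

    dataset = db.session.query(Dataset).filter(Dataset.id == 'your-dataset-id').first()
    vector = Vector(dataset=dataset)  # backend chosen from VECTOR_STORE or the dataset's index_struct_dict

    documents = [Document(page_content='hello dify', metadata={
        'doc_id': 'your-doc-uuid',          # real ingestion uses UUIDs here
        'dataset_id': dataset.id,
        'document_id': 'your-document-id',
        'doc_hash': 'your-doc-hash'
    })]
    vector.create(texts=documents)          # embeds page_content, then writes to the chosen store
    top_docs = vector.search_by_vector('what is dify?', top_k=4, score_threshold=0.0)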
| import datetime | |||||
| from typing import Any, Optional | |||||
| import requests | |||||
| import weaviate | |||||
| from pydantic import BaseModel, root_validator | |||||
| from core.rag.datasource.vdb.field import Field | |||||
| from core.rag.datasource.vdb.vector_base import BaseVector | |||||
| from core.rag.models.document import Document | |||||
| from models.dataset import Dataset | |||||
| class WeaviateConfig(BaseModel): | |||||
| endpoint: str | |||||
| api_key: Optional[str] | |||||
| batch_size: int = 100 | |||||
| @root_validator() | |||||
| def validate_config(cls, values: dict) -> dict: | |||||
| if not values['endpoint']: | |||||
| raise ValueError("config WEAVIATE_ENDPOINT is required") | |||||
| return values | |||||
| class WeaviateVector(BaseVector): | |||||
| def __init__(self, collection_name: str, config: WeaviateConfig, attributes: list): | |||||
| super().__init__(collection_name) | |||||
| self._client = self._init_client(config) | |||||
| self._attributes = attributes | |||||
| def _init_client(self, config: WeaviateConfig) -> weaviate.Client: | |||||
| auth_config = weaviate.auth.AuthApiKey(api_key=config.api_key) | |||||
| weaviate.connect.connection.has_grpc = False | |||||
| try: | |||||
| client = weaviate.Client( | |||||
| url=config.endpoint, | |||||
| auth_client_secret=auth_config, | |||||
| timeout_config=(5, 60), | |||||
| startup_period=None | |||||
| ) | |||||
| except requests.exceptions.ConnectionError: | |||||
| raise ConnectionError("Vector database connection error") | |||||
| client.batch.configure( | |||||
| # `batch_size` takes an `int` value to enable auto-batching | |||||
| # (`None` is used for manual batching) | |||||
| batch_size=config.batch_size, | |||||
| # dynamically update the `batch_size` based on import speed | |||||
| dynamic=True, | |||||
| # `timeout_retries` takes an `int` value to retry on time outs | |||||
| timeout_retries=3, | |||||
| ) | |||||
| return client | |||||
| def get_type(self) -> str: | |||||
| return 'weaviate' | |||||
| def get_collection_name(self, dataset: Dataset) -> str: | |||||
| if dataset.index_struct_dict: | |||||
| class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] | |||||
| if not class_prefix.endswith('_Node'): | |||||
| # original class_prefix | |||||
| class_prefix += '_Node' | |||||
| return class_prefix | |||||
| dataset_id = dataset.id | |||||
| return "Vector_index_" + dataset_id.replace("-", "_") + '_Node' | |||||
| def to_index_struct(self) -> dict: | |||||
| return { | |||||
| "type": self.get_type(), | |||||
| "vector_store": {"class_prefix": self._collection_name} | |||||
| } | |||||
| def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): | |||||
| schema = self._default_schema(self._collection_name) | |||||
| # check whether the index already exists | |||||
| if not self._client.schema.contains(schema): | |||||
| # create collection | |||||
| self._client.schema.create_class(schema) | |||||
| # create vector | |||||
| self.add_texts(texts, embeddings) | |||||
| def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): | |||||
| uuids = self._get_uuids(documents) | |||||
| texts = [d.page_content for d in documents] | |||||
| metadatas = [d.metadata for d in documents] | |||||
| ids = [] | |||||
| with self._client.batch as batch: | |||||
| for i, text in enumerate(texts): | |||||
| data_properties = {Field.TEXT_KEY.value: text} | |||||
| if metadatas is not None: | |||||
| for key, val in metadatas[i].items(): | |||||
| data_properties[key] = self._json_serializable(val) | |||||
| batch.add_data_object( | |||||
| data_object=data_properties, | |||||
| class_name=self._collection_name, | |||||
| uuid=uuids[i], | |||||
| vector=embeddings[i] if embeddings else None, | |||||
| ) | |||||
| ids.append(uuids[i]) | |||||
| return ids | |||||
| def delete_by_metadata_field(self, key: str, value: str): | |||||
| where_filter = { | |||||
| "operator": "Equal", | |||||
| "path": [key], | |||||
| "valueText": value | |||||
| } | |||||
| self._client.batch.delete_objects( | |||||
| class_name=self._collection_name, | |||||
| where=where_filter, | |||||
| output='minimal' | |||||
| ) | |||||
| def delete(self): | |||||
| self._client.schema.delete_class(self._collection_name) | |||||
| def text_exists(self, id: str) -> bool: | |||||
| collection_name = self._collection_name | |||||
| result = self._client.query.get(collection_name).with_additional(["id"]).with_where({ | |||||
| "path": ["doc_id"], | |||||
| "operator": "Equal", | |||||
| "valueText": id, | |||||
| }).with_limit(1).do() | |||||
| if "errors" in result: | |||||
| raise ValueError(f"Error during query: {result['errors']}") | |||||
| entries = result["data"]["Get"][collection_name] | |||||
| if len(entries) == 0: | |||||
| return False | |||||
| return True | |||||
| def delete_by_ids(self, ids: list[str]) -> None: | |||||
| self._client.data_object.delete( | |||||
| ids, | |||||
| class_name=self._collection_name | |||||
| ) | |||||
| def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: | |||||
| """Look up similar documents by embedding vector in Weaviate.""" | |||||
| collection_name = self._collection_name | |||||
| # copy to avoid mutating self._attributes on every call | |||||
| properties = list(self._attributes) | |||||
| properties.append(Field.TEXT_KEY.value) | |||||
| query_obj = self._client.query.get(collection_name, properties) | |||||
| vector = {"vector": query_vector} | |||||
| if kwargs.get("where_filter"): | |||||
| query_obj = query_obj.with_where(kwargs.get("where_filter")) | |||||
| result = ( | |||||
| query_obj.with_near_vector(vector) | |||||
| .with_limit(kwargs.get("top_k", 4)) | |||||
| .with_additional(["vector", "distance"]) | |||||
| .do() | |||||
| ) | |||||
| if "errors" in result: | |||||
| raise ValueError(f"Error during query: {result['errors']}") | |||||
| docs_and_scores = [] | |||||
| for res in result["data"]["Get"][collection_name]: | |||||
| text = res.pop(Field.TEXT_KEY.value) | |||||
| score = 1 - res["_additional"]["distance"] | |||||
| docs_and_scores.append((Document(page_content=text, metadata=res), score)) | |||||
| docs = [] | |||||
| for doc, score in docs_and_scores: | |||||
| score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0 | |||||
| # check score threshold | |||||
| if score > score_threshold: | |||||
| doc.metadata['score'] = score | |||||
| docs.append(doc) | |||||
| return docs | |||||
| def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: | |||||
| """Return docs using BM25F. | |||||
| Args: | |||||
| query: Text to look up documents similar to. | |||||
| k: Number of Documents to return. Defaults to 4. | |||||
| Returns: | |||||
| List of Documents most similar to the query. | |||||
| """ | |||||
| collection_name = self._collection_name | |||||
| content: dict[str, Any] = {"concepts": [query]} | |||||
| # copy to avoid mutating self._attributes on every call | |||||
| properties = list(self._attributes) | |||||
| properties.append(Field.TEXT_KEY.value) | |||||
| if kwargs.get("search_distance"): | |||||
| content["certainty"] = kwargs.get("search_distance") | |||||
| query_obj = self._client.query.get(collection_name, properties) | |||||
| if kwargs.get("where_filter"): | |||||
| query_obj = query_obj.with_where(kwargs.get("where_filter")) | |||||
| if kwargs.get("additional"): | |||||
| query_obj = query_obj.with_additional(kwargs.get("additional")) | |||||
| properties = ['text'] | |||||
| result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get('top_k', 2)).do() | |||||
| if "errors" in result: | |||||
| raise ValueError(f"Error during query: {result['errors']}") | |||||
| docs = [] | |||||
| for res in result["data"]["Get"][collection_name]: | |||||
| text = res.pop(Field.TEXT_KEY.value) | |||||
| docs.append(Document(page_content=text, metadata=res)) | |||||
| return docs | |||||
| def _default_schema(self, index_name: str) -> dict: | |||||
| return { | |||||
| "class": index_name, | |||||
| "properties": [ | |||||
| { | |||||
| "name": "text", | |||||
| "dataType": ["text"], | |||||
| } | |||||
| ], | |||||
| } | |||||
| def _json_serializable(self, value: Any) -> Any: | |||||
| if isinstance(value, datetime.datetime): | |||||
| return value.isoformat() | |||||
| return value |
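A rough usage sketch (not part of the change set; endpoint, API key and embeddings are placeholders, and a reachable Weaviate instance is assumed):

    import uuid
    from core.rag.models.document import Document

    config = WeaviateConfig(endpoint='http://localhost:8080', api_key='your-weaviate-api-key')
    vector = WeaviateVector(
        collection_name='Vector_index_example_Node',
        config=config,
        attributes=['doc_id', 'dataset_id', 'document_id', 'doc_hash']
    )
    docs = [Document(page_content='hello weaviate', metadata={
        'doc_id': str(uuid.uuid4()),        # doubles as the Weaviate object uuid
        'dataset_id': 'your-dataset-id',
        'document_id': 'your-document-id',
        'doc_hash': 'your-doc-hash'
    })]
    vector.create(texts=docs, embeddings=[[0.1, 0.2, 0.3]])
    semantic_hits = vector.search_by_vector([0.1, 0.2, 0.3], top_k=2)
    bm25_hits = vector.search_by_full_text('hello', top_k=2)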
| """Schema for Blobs and Blob Loaders. | |||||
| The goal is to facilitate decoupling of content loading from content parsing code. | |||||
| In addition, content loading code should provide a lazy loading interface by default. | |||||
| """ | |||||
| from __future__ import annotations | |||||
| import contextlib | |||||
| import mimetypes | |||||
| from abc import ABC, abstractmethod | |||||
| from collections.abc import Generator, Iterable, Mapping | |||||
| from io import BufferedReader, BytesIO | |||||
| from pathlib import PurePath | |||||
| from typing import Any, Optional, Union | |||||
| from pydantic import BaseModel, root_validator | |||||
| PathLike = Union[str, PurePath] | |||||
| class Blob(BaseModel): | |||||
| """A blob is used to represent raw data by either reference or value. | |||||
| Provides an interface to materialize the blob in different representations, and | |||||
| helps to decouple the development of data loaders from the downstream parsing of | |||||
| the raw data. | |||||
| Inspired by: https://developer.mozilla.org/en-US/docs/Web/API/Blob | |||||
| """ | |||||
| data: Union[bytes, str, None] # Raw data | |||||
| mimetype: Optional[str] = None # Not to be confused with a file extension | |||||
| encoding: str = "utf-8" # Use utf-8 as default encoding, if decoding to string | |||||
| # Location where the original content was found | |||||
| # Represent location on the local file system | |||||
| # Useful for situations where downstream code assumes it must work with file paths | |||||
| # rather than in-memory content. | |||||
| path: Optional[PathLike] = None | |||||
| class Config: | |||||
| arbitrary_types_allowed = True | |||||
| frozen = True | |||||
| @property | |||||
| def source(self) -> Optional[str]: | |||||
| """The source location of the blob as string if known otherwise none.""" | |||||
| return str(self.path) if self.path else None | |||||
| @root_validator(pre=True) | |||||
| def check_blob_is_valid(cls, values: Mapping[str, Any]) -> Mapping[str, Any]: | |||||
| """Verify that either data or path is provided.""" | |||||
| if "data" not in values and "path" not in values: | |||||
| raise ValueError("Either data or path must be provided") | |||||
| return values | |||||
| def as_string(self) -> str: | |||||
| """Read data as a string.""" | |||||
| if self.data is None and self.path: | |||||
| with open(str(self.path), encoding=self.encoding) as f: | |||||
| return f.read() | |||||
| elif isinstance(self.data, bytes): | |||||
| return self.data.decode(self.encoding) | |||||
| elif isinstance(self.data, str): | |||||
| return self.data | |||||
| else: | |||||
| raise ValueError(f"Unable to get string for blob {self}") | |||||
| def as_bytes(self) -> bytes: | |||||
| """Read data as bytes.""" | |||||
| if isinstance(self.data, bytes): | |||||
| return self.data | |||||
| elif isinstance(self.data, str): | |||||
| return self.data.encode(self.encoding) | |||||
| elif self.data is None and self.path: | |||||
| with open(str(self.path), "rb") as f: | |||||
| return f.read() | |||||
| else: | |||||
| raise ValueError(f"Unable to get bytes for blob {self}") | |||||
| @contextlib.contextmanager | |||||
| def as_bytes_io(self) -> Generator[Union[BytesIO, BufferedReader], None, None]: | |||||
| """Read data as a byte stream.""" | |||||
| if isinstance(self.data, bytes): | |||||
| yield BytesIO(self.data) | |||||
| elif self.data is None and self.path: | |||||
| with open(str(self.path), "rb") as f: | |||||
| yield f | |||||
| else: | |||||
| raise NotImplementedError(f"Unable to convert blob {self}") | |||||
| @classmethod | |||||
| def from_path( | |||||
| cls, | |||||
| path: PathLike, | |||||
| *, | |||||
| encoding: str = "utf-8", | |||||
| mime_type: Optional[str] = None, | |||||
| guess_type: bool = True, | |||||
| ) -> Blob: | |||||
| """Load the blob from a path like object. | |||||
| Args: | |||||
| path: path like object to file to be read | |||||
| encoding: Encoding to use if decoding the bytes into a string | |||||
| mime_type: if provided, will be set as the mime-type of the data | |||||
| guess_type: If True, the mimetype will be guessed from the file extension, | |||||
| if a mime-type was not provided | |||||
| Returns: | |||||
| Blob instance | |||||
| """ | |||||
| if mime_type is None and guess_type: | |||||
| _mimetype = mimetypes.guess_type(path)[0] if guess_type else None | |||||
| else: | |||||
| _mimetype = mime_type | |||||
| # We do not load the data immediately, instead we treat the blob as a | |||||
| # reference to the underlying data. | |||||
| return cls(data=None, mimetype=_mimetype, encoding=encoding, path=path) | |||||
| @classmethod | |||||
| def from_data( | |||||
| cls, | |||||
| data: Union[str, bytes], | |||||
| *, | |||||
| encoding: str = "utf-8", | |||||
| mime_type: Optional[str] = None, | |||||
| path: Optional[str] = None, | |||||
| ) -> Blob: | |||||
| """Initialize the blob from in-memory data. | |||||
| Args: | |||||
| data: the in-memory data associated with the blob | |||||
| encoding: Encoding to use if decoding the bytes into a string | |||||
| mime_type: if provided, will be set as the mime-type of the data | |||||
| path: if provided, will be set as the source from which the data came | |||||
| Returns: | |||||
| Blob instance | |||||
| """ | |||||
| return cls(data=data, mimetype=mime_type, encoding=encoding, path=path) | |||||
| def __repr__(self) -> str: | |||||
| """Define the blob representation.""" | |||||
| str_repr = f"Blob {id(self)}" | |||||
| if self.source: | |||||
| str_repr += f" {self.source}" | |||||
| return str_repr | |||||
| class BlobLoader(ABC): | |||||
| """Abstract interface for blob loaders implementation. | |||||
| Implementer should be able to load raw content from a datasource system according | |||||
| to some criteria and return the raw content lazily as a stream of blobs. | |||||
| """ | |||||
| @abstractmethod | |||||
| def yield_blobs( | |||||
| self, | |||||
| ) -> Iterable[Blob]: | |||||
| """A lazy loader for raw data represented by LangChain's Blob object. | |||||
| Returns: | |||||
| A generator over blobs | |||||
| """ |
| """Abstract interface for document loader implementations.""" | |||||
| import csv | |||||
| from typing import Optional | |||||
| from core.rag.extractor.extractor_base import BaseExtractor | |||||
| from core.rag.extractor.helpers import detect_file_encodings | |||||
| from core.rag.models.document import Document | |||||
| class CSVExtractor(BaseExtractor): | |||||
| """Load CSV files. | |||||
| Args: | |||||
| file_path: Path to the file to load. | |||||
| """ | |||||
| def __init__( | |||||
| self, | |||||
| file_path: str, | |||||
| encoding: Optional[str] = None, | |||||
| autodetect_encoding: bool = False, | |||||
| source_column: Optional[str] = None, | |||||
| csv_args: Optional[dict] = None, | |||||
| ): | |||||
| """Initialize with file path.""" | |||||
| self._file_path = file_path | |||||
| self._encoding = encoding | |||||
| self._autodetect_encoding = autodetect_encoding | |||||
| self.source_column = source_column | |||||
| self.csv_args = csv_args or {} | |||||
| def extract(self) -> list[Document]: | |||||
| """Load data into document objects.""" | |||||
| try: | |||||
| with open(self._file_path, newline="", encoding=self._encoding) as csvfile: | |||||
| docs = self._read_from_file(csvfile) | |||||
| except UnicodeDecodeError as e: | |||||
| if self._autodetect_encoding: | |||||
| detected_encodings = detect_file_encodings(self._file_path) | |||||
| for encoding in detected_encodings: | |||||
| try: | |||||
| with open(self._file_path, newline="", encoding=encoding.encoding) as csvfile: | |||||
| docs = self._read_from_file(csvfile) | |||||
| break | |||||
| except UnicodeDecodeError: | |||||
| continue | |||||
| else: | |||||
| raise RuntimeError(f"Error loading {self._file_path}") from e | |||||
| return docs | |||||
| def _read_from_file(self, csvfile) -> list[Document]: | |||||
| docs = [] | |||||
| csv_reader = csv.DictReader(csvfile, **self.csv_args) # type: ignore | |||||
| for i, row in enumerate(csv_reader): | |||||
| content = "\n".join(f"{k.strip()}: {v.strip()}" for k, v in row.items()) | |||||
| try: | |||||
| source = ( | |||||
| row[self.source_column] | |||||
| if self.source_column is not None | |||||
| else '' | |||||
| ) | |||||
| except KeyError: | |||||
| raise ValueError( | |||||
| f"Source column '{self.source_column}' not found in CSV file." | |||||
| ) | |||||
| metadata = {"source": source, "row": i} | |||||
| doc = Document(page_content=content, metadata=metadata) | |||||
| docs.append(doc) | |||||
| return docs |
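A short usage sketch (not part of the change set; the file path is a placeholder):

    extractor = CSVExtractor('data/faq.csv', autodetect_encoding=True)
    documents = extractor.extract()
    for doc in documents:
        print(doc.metadata['row'], doc.page_content[:60])   # one Document per CSV row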
| from enum import Enum | |||||
| class DatasourceType(Enum): | |||||
| FILE = "upload_file" | |||||
| NOTION = "notion_import" |
| from pydantic import BaseModel | |||||
| from models.dataset import Document | |||||
| from models.model import UploadFile | |||||
| class NotionInfo(BaseModel): | |||||
| """ | |||||
| Notion import info. | |||||
| """ | |||||
| notion_workspace_id: str | |||||
| notion_obj_id: str | |||||
| notion_page_type: str | |||||
| document: Document = None | |||||
| class Config: | |||||
| arbitrary_types_allowed = True | |||||
| def __init__(self, **data) -> None: | |||||
| super().__init__(**data) | |||||
| class ExtractSetting(BaseModel): | |||||
| """ | |||||
| Model class for document extract settings. | |||||
| """ | |||||
| datasource_type: str | |||||
| upload_file: UploadFile = None | |||||
| notion_info: NotionInfo = None | |||||
| document_model: str = None | |||||
| class Config: | |||||
| arbitrary_types_allowed = True | |||||
| def __init__(self, **data) -> None: | |||||
| super().__init__(**data) |
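Two illustrative ways to build an ExtractSetting (not part of the change set; ids are placeholders and upload_file stands for an existing models.model.UploadFile row loaded elsewhere):

    # file-based extraction
    file_setting = ExtractSetting(
        datasource_type='upload_file',
        upload_file=upload_file,        # assumed to be fetched from the database beforehand
        document_model='text_model'
    )

    # Notion-based extraction
    notion_setting = ExtractSetting(
        datasource_type='notion_import',
        notion_info=NotionInfo(
            notion_workspace_id='your-workspace-id',
            notion_obj_id='your-page-id',
            notion_page_type='page'
        ),
        document_model='text_model'
    )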
| import logging | |||||
| """Abstract interface for document loader implementations.""" | |||||
| from typing import Optional | |||||
| from langchain.document_loaders.base import BaseLoader | |||||
| from langchain.schema import Document | |||||
| from openpyxl.reader.excel import load_workbook | from openpyxl.reader.excel import load_workbook | ||||
| logger = logging.getLogger(__name__) | |||||
| from core.rag.extractor.extractor_base import BaseExtractor | |||||
| from core.rag.models.document import Document | |||||
| class ExcelLoader(BaseLoader): | |||||
| """Load xlxs files. | |||||
| class ExcelExtractor(BaseExtractor): | |||||
| """Load Excel files. | |||||
| Args: | Args: | ||||
| """ | """ | ||||
| def __init__( | def __init__( | ||||
| self, | |||||
| file_path: str | |||||
| self, | |||||
| file_path: str, | |||||
| encoding: Optional[str] = None, | |||||
| autodetect_encoding: bool = False | |||||
| ): | ): | ||||
| """Initialize with file path.""" | """Initialize with file path.""" | ||||
| self._file_path = file_path | self._file_path = file_path | ||||
| self._encoding = encoding | |||||
| self._autodetect_encoding = autodetect_encoding | |||||
| def load(self) -> list[Document]: | |||||
| def extract(self) -> list[Document]: | |||||
| """Load from file path.""" | |||||
| data = [] | data = [] | ||||
| keys = [] | keys = [] | ||||
| wb = load_workbook(filename=self._file_path, read_only=True) | wb = load_workbook(filename=self._file_path, read_only=True) |
| import tempfile | |||||
| from pathlib import Path | |||||
| from typing import Union | |||||
| import requests | |||||
| from flask import current_app | |||||
| from core.rag.extractor.csv_extractor import CSVExtractor | |||||
| from core.rag.extractor.entity.datasource_type import DatasourceType | |||||
| from core.rag.extractor.entity.extract_setting import ExtractSetting | |||||
| from core.rag.extractor.excel_extractor import ExcelExtractor | |||||
| from core.rag.extractor.html_extractor import HtmlExtractor | |||||
| from core.rag.extractor.markdown_extractor import MarkdownExtractor | |||||
| from core.rag.extractor.notion_extractor import NotionExtractor | |||||
| from core.rag.extractor.pdf_extractor import PdfExtractor | |||||
| from core.rag.extractor.text_extractor import TextExtractor | |||||
| from core.rag.extractor.unstructured.unstructured_doc_extractor import UnstructuredWordExtractor | |||||
| from core.rag.extractor.unstructured.unstructured_eml_extractor import UnstructuredEmailExtractor | |||||
| from core.rag.extractor.unstructured.unstructured_markdown_extractor import UnstructuredMarkdownExtractor | |||||
| from core.rag.extractor.unstructured.unstructured_msg_extractor import UnstructuredMsgExtractor | |||||
| from core.rag.extractor.unstructured.unstructured_ppt_extractor import UnstructuredPPTExtractor | |||||
| from core.rag.extractor.unstructured.unstructured_pptx_extractor import UnstructuredPPTXExtractor | |||||
| from core.rag.extractor.unstructured.unstructured_text_extractor import UnstructuredTextExtractor | |||||
| from core.rag.extractor.unstructured.unstructured_xml_extractor import UnstructuredXmlExtractor | |||||
| from core.rag.extractor.word_extractor import WordExtractor | |||||
| from core.rag.models.document import Document | |||||
| from extensions.ext_storage import storage | |||||
| from models.model import UploadFile | |||||
| SUPPORT_URL_CONTENT_TYPES = ['application/pdf', 'text/plain'] | |||||
| USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" | |||||
| class ExtractProcessor: | |||||
| @classmethod | |||||
| def load_from_upload_file(cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False) \ | |||||
| -> Union[list[Document], str]: | |||||
| extract_setting = ExtractSetting( | |||||
| datasource_type="upload_file", | |||||
| upload_file=upload_file, | |||||
| document_model='text_model' | |||||
| ) | |||||
| if return_text: | |||||
| delimiter = '\n' | |||||
| return delimiter.join([document.page_content for document in cls.extract(extract_setting, is_automatic)]) | |||||
| else: | |||||
| return cls.extract(extract_setting, is_automatic) | |||||
| @classmethod | |||||
| def load_from_url(cls, url: str, return_text: bool = False) -> Union[list[Document], str]: | |||||
| response = requests.get(url, headers={ | |||||
| "User-Agent": USER_AGENT | |||||
| }) | |||||
| with tempfile.TemporaryDirectory() as temp_dir: | |||||
| suffix = Path(url).suffix | |||||
| file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" | |||||
| with open(file_path, 'wb') as file: | |||||
| file.write(response.content) | |||||
| extract_setting = ExtractSetting( | |||||
| datasource_type="upload_file", | |||||
| document_model='text_model' | |||||
| ) | |||||
| if return_text: | |||||
| delimiter = '\n' | |||||
| return delimiter.join([document.page_content for document in cls.extract( | |||||
| extract_setting=extract_setting, file_path=file_path)]) | |||||
| else: | |||||
| return cls.extract(extract_setting=extract_setting, file_path=file_path) | |||||
| @classmethod | |||||
| def extract(cls, extract_setting: ExtractSetting, is_automatic: bool = False, | |||||
| file_path: str = None) -> list[Document]: | |||||
| if extract_setting.datasource_type == DatasourceType.FILE.value: | |||||
| with tempfile.TemporaryDirectory() as temp_dir: | |||||
| if not file_path: | |||||
| upload_file: UploadFile = extract_setting.upload_file | |||||
| suffix = Path(upload_file.key).suffix | |||||
| file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" | |||||
| storage.download(upload_file.key, file_path) | |||||
| input_file = Path(file_path) | |||||
| file_extension = input_file.suffix.lower() | |||||
| etl_type = current_app.config['ETL_TYPE'] | |||||
| unstructured_api_url = current_app.config['UNSTRUCTURED_API_URL'] | |||||
| if etl_type == 'Unstructured': | |||||
| if file_extension == '.xlsx': | |||||
| extractor = ExcelExtractor(file_path) | |||||
| elif file_extension == '.pdf': | |||||
| extractor = PdfExtractor(file_path) | |||||
| elif file_extension in ['.md', '.markdown']: | |||||
| extractor = UnstructuredMarkdownExtractor(file_path, unstructured_api_url) if is_automatic \ | |||||
| else MarkdownExtractor(file_path, autodetect_encoding=True) | |||||
| elif file_extension in ['.htm', '.html']: | |||||
| extractor = HtmlExtractor(file_path) | |||||
| elif file_extension in ['.docx']: | |||||
| extractor = UnstructuredWordExtractor(file_path, unstructured_api_url) | |||||
| elif file_extension == '.csv': | |||||
| extractor = CSVExtractor(file_path, autodetect_encoding=True) | |||||
| elif file_extension == '.msg': | |||||
| extractor = UnstructuredMsgExtractor(file_path, unstructured_api_url) | |||||
| elif file_extension == '.eml': | |||||
| extractor = UnstructuredEmailExtractor(file_path, unstructured_api_url) | |||||
| elif file_extension == '.ppt': | |||||
| extractor = UnstructuredPPTExtractor(file_path, unstructured_api_url) | |||||
| elif file_extension == '.pptx': | |||||
| extractor = UnstructuredPPTXExtractor(file_path, unstructured_api_url) | |||||
| elif file_extension == '.xml': | |||||
| extractor = UnstructuredXmlExtractor(file_path, unstructured_api_url) | |||||
| else: | |||||
| # txt | |||||
| extractor = UnstructuredTextExtractor(file_path, unstructured_api_url) if is_automatic \ | |||||
| else TextExtractor(file_path, autodetect_encoding=True) | |||||
| else: | |||||
| if file_extension == '.xlsx': | |||||
| extractor = ExcelExtractor(file_path) | |||||
| elif file_extension == '.pdf': | |||||
| extractor = PdfExtractor(file_path) | |||||
| elif file_extension in ['.md', '.markdown']: | |||||
| extractor = MarkdownExtractor(file_path, autodetect_encoding=True) | |||||
| elif file_extension in ['.htm', '.html']: | |||||
| extractor = HtmlExtractor(file_path) | |||||
| elif file_extension in ['.docx']: | |||||
| extractor = WordExtractor(file_path) | |||||
| elif file_extension == '.csv': | |||||
| extractor = CSVExtractor(file_path, autodetect_encoding=True) | |||||
| else: | |||||
| # txt | |||||
| extractor = TextExtractor(file_path, autodetect_encoding=True) | |||||
| return extractor.extract() | |||||
| elif extract_setting.datasource_type == DatasourceType.NOTION.value: | |||||
| extractor = NotionExtractor( | |||||
| notion_workspace_id=extract_setting.notion_info.notion_workspace_id, | |||||
| notion_obj_id=extract_setting.notion_info.notion_obj_id, | |||||
| notion_page_type=extract_setting.notion_info.notion_page_type, | |||||
| document_model=extract_setting.notion_info.document | |||||
| ) | |||||
| return extractor.extract() | |||||
| else: | |||||
| raise ValueError(f"Unsupported datasource type: {extract_setting.datasource_type}") |
| """Abstract interface for document loader implementations.""" | |||||
| from abc import ABC, abstractmethod | |||||
| class BaseExtractor(ABC): | |||||
| """Interface for extract files. | |||||
| """ | |||||
| @abstractmethod | |||||
| def extract(self): | |||||
| raise NotImplementedError | |||||
| """Document loader helpers.""" | |||||
| import concurrent.futures | |||||
| from typing import NamedTuple, Optional, cast | |||||
| class FileEncoding(NamedTuple): | |||||
| """A file encoding as the NamedTuple.""" | |||||
| encoding: Optional[str] | |||||
| """The encoding of the file.""" | |||||
| confidence: float | |||||
| """The confidence of the encoding.""" | |||||
| language: Optional[str] | |||||
| """The language of the file.""" | |||||
| def detect_file_encodings(file_path: str, timeout: int = 5) -> list[FileEncoding]: | |||||
| """Try to detect the file encoding. | |||||
| Returns a list of `FileEncoding` tuples with the detected encodings ordered | |||||
| by confidence. | |||||
| Args: | |||||
| file_path: The path to the file to detect the encoding for. | |||||
| timeout: The timeout in seconds for the encoding detection. | |||||
| """ | |||||
| import chardet | |||||
| def read_and_detect(file_path: str) -> list[dict]: | |||||
| with open(file_path, "rb") as f: | |||||
| rawdata = f.read() | |||||
| return cast(list[dict], chardet.detect_all(rawdata)) | |||||
| with concurrent.futures.ThreadPoolExecutor() as executor: | |||||
| future = executor.submit(read_and_detect, file_path) | |||||
| try: | |||||
| encodings = future.result(timeout=timeout) | |||||
| except concurrent.futures.TimeoutError: | |||||
| raise TimeoutError( | |||||
| f"Timeout reached while detecting encoding for {file_path}" | |||||
| ) | |||||
| if all(encoding["encoding"] is None for encoding in encodings): | |||||
| raise RuntimeError(f"Could not detect encoding for {file_path}") | |||||
| return [FileEncoding(**enc) for enc in encodings if enc["encoding"] is not None] |
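`detect_file_encodings` runs `chardet.detect_all` in a worker thread so a huge or pathological file cannot block indexing for longer than `timeout` seconds. A hedged sketch of how the extractors below consume it; the helper function name is made up for illustration.

```python
# Sketch only: try the detected encodings in confidence order until one decodes.
from core.rag.extractor.helpers import detect_file_encodings

def read_text_with_fallback(file_path: str) -> str:
    try:
        with open(file_path, encoding=None) as f:   # system default first
            return f.read()
    except UnicodeDecodeError:
        for candidate in detect_file_encodings(file_path, timeout=5):
            try:
                with open(file_path, encoding=candidate.encoding) as f:
                    return f.read()
            except UnicodeDecodeError:
                continue
    raise RuntimeError(f"Could not decode {file_path}")
```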
| import csv | |||||
| import logging | |||||
| """Abstract interface for document loader implementations.""" | |||||
| from typing import Optional | from typing import Optional | ||||
| from langchain.document_loaders import CSVLoader as LCCSVLoader | |||||
| from langchain.document_loaders.helpers import detect_file_encodings | |||||
| from langchain.schema import Document | |||||
| from core.rag.extractor.extractor_base import BaseExtractor | |||||
| from core.rag.extractor.helpers import detect_file_encodings | |||||
| from core.rag.models.document import Document | |||||
| logger = logging.getLogger(__name__) | |||||
| class CSVExtractor(BaseExtractor): | |||||
| """Load CSV files. | |||||
| Args: | |||||
| file_path: Path to the file to load. | |||||
| """ | |||||
| class CSVLoader(LCCSVLoader): | |||||
| def __init__( | def __init__( | ||||
| self, | self, | ||||
| file_path: str, | file_path: str, | ||||
| encoding: Optional[str] = None, | |||||
| autodetect_encoding: bool = False, | |||||
| source_column: Optional[str] = None, | source_column: Optional[str] = None, | ||||
| csv_args: Optional[dict] = None, | csv_args: Optional[dict] = None, | ||||
| encoding: Optional[str] = None, | |||||
| autodetect_encoding: bool = True, | |||||
| ): | ): | ||||
| self.file_path = file_path | |||||
| """Initialize with file path.""" | |||||
| self._file_path = file_path | |||||
| self._encoding = encoding | |||||
| self._autodetect_encoding = autodetect_encoding | |||||
| self.source_column = source_column | self.source_column = source_column | ||||
| self.encoding = encoding | |||||
| self.csv_args = csv_args or {} | self.csv_args = csv_args or {} | ||||
| self.autodetect_encoding = autodetect_encoding | |||||
| def load(self) -> list[Document]: | |||||
| def extract(self) -> list[Document]: | |||||
| """Load data into document objects.""" | """Load data into document objects.""" | ||||
| try: | try: | ||||
| with open(self.file_path, newline="", encoding=self.encoding) as csvfile: | |||||
| with open(self._file_path, newline="", encoding=self._encoding) as csvfile: | |||||
| docs = self._read_from_file(csvfile) | docs = self._read_from_file(csvfile) | ||||
| except UnicodeDecodeError as e: | except UnicodeDecodeError as e: | ||||
| if self.autodetect_encoding: | |||||
| detected_encodings = detect_file_encodings(self.file_path) | |||||
| if self._autodetect_encoding: | |||||
| detected_encodings = detect_file_encodings(self._file_path) | |||||
| for encoding in detected_encodings: | for encoding in detected_encodings: | ||||
| logger.debug("Trying encoding: ", encoding.encoding) | |||||
| try: | try: | ||||
| with open(self.file_path, newline="", encoding=encoding.encoding) as csvfile: | |||||
| with open(self._file_path, newline="", encoding=encoding.encoding) as csvfile: | |||||
| docs = self._read_from_file(csvfile) | docs = self._read_from_file(csvfile) | ||||
| break | break | ||||
| except UnicodeDecodeError: | except UnicodeDecodeError: | ||||
| continue | continue | ||||
| else: | else: | ||||
| raise RuntimeError(f"Error loading {self.file_path}") from e | |||||
| raise RuntimeError(f"Error loading {self._file_path}") from e | |||||
| return docs | return docs | ||||
| def _read_from_file(self, csvfile): | |||||
| def _read_from_file(self, csvfile) -> list[Document]: | |||||
| docs = [] | docs = [] | ||||
| csv_reader = csv.DictReader(csvfile, **self.csv_args) # type: ignore | csv_reader = csv.DictReader(csvfile, **self.csv_args) # type: ignore | ||||
| for i, row in enumerate(csv_reader): | for i, row in enumerate(csv_reader): |
| import logging | |||||
| """Abstract interface for document loader implementations.""" | |||||
| import re | import re | ||||
| from typing import Optional, cast | from typing import Optional, cast | ||||
| from langchain.document_loaders.base import BaseLoader | |||||
| from langchain.document_loaders.helpers import detect_file_encodings | |||||
| from langchain.schema import Document | |||||
| from core.rag.extractor.extractor_base import BaseExtractor | |||||
| from core.rag.extractor.helpers import detect_file_encodings | |||||
| from core.rag.models.document import Document | |||||
| logger = logging.getLogger(__name__) | |||||
| class MarkdownLoader(BaseLoader): | |||||
| """Load md files. | |||||
| class MarkdownExtractor(BaseExtractor): | |||||
| """Load Markdown files. | |||||
| Args: | Args: | ||||
| file_path: Path to the file to load. | file_path: Path to the file to load. | ||||
| remove_hyperlinks: Whether to remove hyperlinks from the text. | |||||
| remove_images: Whether to remove images from the text. | |||||
| encoding: File encoding to use. If `None`, the file will be loaded | |||||
| with the default system encoding. | |||||
| autodetect_encoding: Whether to try to autodetect the file encoding | |||||
| if the specified encoding fails. | |||||
| """ | """ | ||||
| def __init__( | def __init__( | ||||
| self, | |||||
| file_path: str, | |||||
| remove_hyperlinks: bool = True, | |||||
| remove_images: bool = True, | |||||
| encoding: Optional[str] = None, | |||||
| autodetect_encoding: bool = True, | |||||
| self, | |||||
| file_path: str, | |||||
| remove_hyperlinks: bool = True, | |||||
| remove_images: bool = True, | |||||
| encoding: Optional[str] = None, | |||||
| autodetect_encoding: bool = True, | |||||
| ): | ): | ||||
| """Initialize with file path.""" | """Initialize with file path.""" | ||||
| self._file_path = file_path | self._file_path = file_path | ||||
| self._encoding = encoding | self._encoding = encoding | ||||
| self._autodetect_encoding = autodetect_encoding | self._autodetect_encoding = autodetect_encoding | ||||
| def load(self) -> list[Document]: | |||||
| def extract(self) -> list[Document]: | |||||
| """Load from file path.""" | |||||
| tups = self.parse_tups(self._file_path) | tups = self.parse_tups(self._file_path) | ||||
| documents = [] | documents = [] | ||||
| for header, value in tups: | for header, value in tups: | ||||
| if self._autodetect_encoding: | if self._autodetect_encoding: | ||||
| detected_encodings = detect_file_encodings(filepath) | detected_encodings = detect_file_encodings(filepath) | ||||
| for encoding in detected_encodings: | for encoding in detected_encodings: | ||||
| logger.debug("Trying encoding: ", encoding.encoding) | |||||
| try: | try: | ||||
| with open(filepath, encoding=encoding.encoding) as f: | with open(filepath, encoding=encoding.encoding) as f: | ||||
| content = f.read() | content = f.read() |
| import requests | import requests | ||||
| from flask import current_app | from flask import current_app | ||||
| from langchain.document_loaders.base import BaseLoader | |||||
| from flask_login import current_user | |||||
| from langchain.schema import Document | from langchain.schema import Document | ||||
| from core.rag.extractor.extractor_base import BaseExtractor | |||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from models.dataset import Document as DocumentModel | from models.dataset import Document as DocumentModel | ||||
| from models.source import DataSourceBinding | from models.source import DataSourceBinding | ||||
| HEADING_TYPE = ['heading_1', 'heading_2', 'heading_3'] | HEADING_TYPE = ['heading_1', 'heading_2', 'heading_3'] | ||||
| class NotionLoader(BaseLoader): | |||||
| class NotionExtractor(BaseExtractor): | |||||
| def __init__( | def __init__( | ||||
| self, | self, | ||||
| notion_access_token: str, | |||||
| notion_workspace_id: str, | notion_workspace_id: str, | ||||
| notion_obj_id: str, | notion_obj_id: str, | ||||
| notion_page_type: str, | notion_page_type: str, | ||||
| document_model: Optional[DocumentModel] = None | |||||
| document_model: Optional[DocumentModel] = None, | |||||
| notion_access_token: Optional[str] = None | |||||
| ): | ): | ||||
| self._notion_access_token = None | |||||
| self._document_model = document_model | self._document_model = document_model | ||||
| self._notion_workspace_id = notion_workspace_id | self._notion_workspace_id = notion_workspace_id | ||||
| self._notion_obj_id = notion_obj_id | self._notion_obj_id = notion_obj_id | ||||
| self._notion_page_type = notion_page_type | self._notion_page_type = notion_page_type | ||||
| self._notion_access_token = notion_access_token | |||||
| if not self._notion_access_token: | |||||
| integration_token = current_app.config.get('NOTION_INTEGRATION_TOKEN') | |||||
| if integration_token is None: | |||||
| raise ValueError( | |||||
| "Must specify `integration_token` or set environment " | |||||
| "variable `NOTION_INTEGRATION_TOKEN`." | |||||
| ) | |||||
| self._notion_access_token = integration_token | |||||
| @classmethod | |||||
| def from_document(cls, document_model: DocumentModel): | |||||
| data_source_info = document_model.data_source_info_dict | |||||
| if not data_source_info or 'notion_page_id' not in data_source_info \ | |||||
| or 'notion_workspace_id' not in data_source_info: | |||||
| raise ValueError("no notion page found") | |||||
| notion_workspace_id = data_source_info['notion_workspace_id'] | |||||
| notion_obj_id = data_source_info['notion_page_id'] | |||||
| notion_page_type = data_source_info['type'] | |||||
| notion_access_token = cls._get_access_token(document_model.tenant_id, notion_workspace_id) | |||||
| return cls( | |||||
| notion_access_token=notion_access_token, | |||||
| notion_workspace_id=notion_workspace_id, | |||||
| notion_obj_id=notion_obj_id, | |||||
| notion_page_type=notion_page_type, | |||||
| document_model=document_model | |||||
| ) | |||||
| def load(self) -> list[Document]: | |||||
| if notion_access_token: | |||||
| self._notion_access_token = notion_access_token | |||||
| else: | |||||
| self._notion_access_token = self._get_access_token(current_user.current_tenant_id, | |||||
| self._notion_workspace_id) | |||||
| if not self._notion_access_token: | |||||
| integration_token = current_app.config.get('NOTION_INTEGRATION_TOKEN') | |||||
| if integration_token is None: | |||||
| raise ValueError( | |||||
| "Must specify `integration_token` or set environment " | |||||
| "variable `NOTION_INTEGRATION_TOKEN`." | |||||
| ) | |||||
| self._notion_access_token = integration_token | |||||
| def extract(self) -> list[Document]: | |||||
| self.update_last_edited_time( | self.update_last_edited_time( | ||||
| self._document_model | self._document_model | ||||
| ) | ) |
| """Abstract interface for document loader implementations.""" | |||||
| from collections.abc import Iterator | |||||
| from typing import Optional | |||||
| from core.rag.extractor.blod.blod import Blob | |||||
| from core.rag.extractor.extractor_base import BaseExtractor | |||||
| from core.rag.models.document import Document | |||||
| from extensions.ext_storage import storage | |||||
| class PdfExtractor(BaseExtractor): | |||||
| """Load pdf files. | |||||
| Args: | |||||
| file_path: Path to the file to load. | |||||
| """ | |||||
| def __init__( | |||||
| self, | |||||
| file_path: str, | |||||
| file_cache_key: Optional[str] = None | |||||
| ): | |||||
| """Initialize with file path.""" | |||||
| self._file_path = file_path | |||||
| self._file_cache_key = file_cache_key | |||||
| def extract(self) -> list[Document]: | |||||
| plaintext_file_key = self._file_cache_key if self._file_cache_key else '' | |||||
| plaintext_file_exists = False | |||||
| if self._file_cache_key: | |||||
| try: | |||||
| text = storage.load(self._file_cache_key).decode('utf-8') | |||||
| plaintext_file_exists = True | |||||
| return [Document(page_content=text)] | |||||
| except FileNotFoundError: | |||||
| pass | |||||
| documents = list(self.load()) | |||||
| text_list = [] | |||||
| for document in documents: | |||||
| text_list.append(document.page_content) | |||||
| text = "\n\n".join(text_list) | |||||
| # save plaintext file for caching | |||||
| if not plaintext_file_exists and plaintext_file_key: | |||||
| storage.save(plaintext_file_key, text.encode('utf-8')) | |||||
| return documents | |||||
| def load( | |||||
| self, | |||||
| ) -> Iterator[Document]: | |||||
| """Lazy load given path as pages.""" | |||||
| blob = Blob.from_path(self._file_path) | |||||
| yield from self.parse(blob) | |||||
| def parse(self, blob: Blob) -> Iterator[Document]: | |||||
| """Lazily parse the blob.""" | |||||
| import pypdfium2 | |||||
| with blob.as_bytes_io() as file_path: | |||||
| pdf_reader = pypdfium2.PdfDocument(file_path, autoclose=True) | |||||
| try: | |||||
| for page_number, page in enumerate(pdf_reader): | |||||
| text_page = page.get_textpage() | |||||
| content = text_page.get_text_range() | |||||
| text_page.close() | |||||
| page.close() | |||||
| metadata = {"source": blob.source, "page": page_number} | |||||
| yield Document(page_content=content, metadata=metadata) | |||||
| finally: | |||||
| pdf_reader.close() |
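`PdfExtractor` parses pages lazily with pypdfium2 and, when a `file_cache_key` is supplied, reuses a previously stored plaintext copy instead of re-parsing the PDF. A usage sketch; the module path, file path and cache key are assumptions for illustration.

```python
# Sketch: the first call parses the PDF page by page; later calls with the
# same cache key can load the cached plaintext from storage instead.
from core.rag.extractor.pdf_extractor import PdfExtractor   # assumed path

extractor = PdfExtractor("/tmp/report.pdf", file_cache_key="plain/report.txt")
pages = extractor.extract()        # one Document per page on a cache miss
print(pages[0].metadata)           # e.g. {'source': '/tmp/report.pdf', 'page': 0}
```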
| """Abstract interface for document loader implementations.""" | |||||
| from typing import Optional | |||||
| from core.rag.extractor.extractor_base import BaseExtractor | |||||
| from core.rag.extractor.helpers import detect_file_encodings | |||||
| from core.rag.models.document import Document | |||||
| class TextExtractor(BaseExtractor): | |||||
| """Load text files. | |||||
| Args: | |||||
| file_path: Path to the file to load. | |||||
| """ | |||||
| def __init__( | |||||
| self, | |||||
| file_path: str, | |||||
| encoding: Optional[str] = None, | |||||
| autodetect_encoding: bool = False | |||||
| ): | |||||
| """Initialize with file path.""" | |||||
| self._file_path = file_path | |||||
| self._encoding = encoding | |||||
| self._autodetect_encoding = autodetect_encoding | |||||
| def extract(self) -> list[Document]: | |||||
| """Load from file path.""" | |||||
| text = "" | |||||
| try: | |||||
| with open(self._file_path, encoding=self._encoding) as f: | |||||
| text = f.read() | |||||
| except UnicodeDecodeError as e: | |||||
| if self._autodetect_encoding: | |||||
| detected_encodings = detect_file_encodings(self._file_path) | |||||
| for encoding in detected_encodings: | |||||
| try: | |||||
| with open(self._file_path, encoding=encoding.encoding) as f: | |||||
| text = f.read() | |||||
| break | |||||
| except UnicodeDecodeError: | |||||
| continue | |||||
| else: | |||||
| raise RuntimeError(f"Error loading {self._file_path}") from e | |||||
| except Exception as e: | |||||
| raise RuntimeError(f"Error loading {self._file_path}") from e | |||||
| metadata = {"source": self._file_path} | |||||
| return [Document(page_content=text, metadata=metadata)] |
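`TextExtractor` mirrors the same pattern for plain text: read with the given encoding, and only fall back to `detect_file_encodings` when decoding fails and `autodetect_encoding` is enabled. A short sketch, with the module path and file name assumed.

```python
# Sketch: extract a text file with unknown encoding; autodetection only runs
# after the default/explicit encoding raises UnicodeDecodeError.
from core.rag.extractor.text_extractor import TextExtractor   # assumed path

docs = TextExtractor("/tmp/notes.txt", autodetect_encoding=True).extract()
assert docs[0].metadata["source"] == "/tmp/notes.txt"
```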
| import logging | |||||
| import os | |||||
| from core.rag.extractor.extractor_base import BaseExtractor | |||||
| from core.rag.models.document import Document | |||||
| logger = logging.getLogger(__name__) | |||||
| class UnstructuredWordExtractor(BaseExtractor): | |||||
| """Loader that uses unstructured to load word documents. | |||||
| """ | |||||
| def __init__( | |||||
| self, | |||||
| file_path: str, | |||||
| api_url: str, | |||||
| ): | |||||
| """Initialize with file path.""" | |||||
| self._file_path = file_path | |||||
| self._api_url = api_url | |||||
| def extract(self) -> list[Document]: | |||||
| from unstructured.__version__ import __version__ as __unstructured_version__ | |||||
| from unstructured.file_utils.filetype import FileType, detect_filetype | |||||
| unstructured_version = tuple( | |||||
| [int(x) for x in __unstructured_version__.split(".")] | |||||
| ) | |||||
| # check the file extension | |||||
| try: | |||||
| import magic # noqa: F401 | |||||
| is_doc = detect_filetype(self._file_path) == FileType.DOC | |||||
| except ImportError: | |||||
| _, extension = os.path.splitext(str(self._file_path)) | |||||
| is_doc = extension == ".doc" | |||||
| if is_doc and unstructured_version < (0, 4, 11): | |||||
| raise ValueError( | |||||
| f"You are on unstructured version {__unstructured_version__}. " | |||||
| "Partitioning .doc files is only supported in unstructured>=0.4.11. " | |||||
| "Please upgrade the unstructured package and try again." | |||||
| ) | |||||
| if is_doc: | |||||
| from unstructured.partition.doc import partition_doc | |||||
| elements = partition_doc(filename=self._file_path) | |||||
| else: | |||||
| from unstructured.partition.docx import partition_docx | |||||
| elements = partition_docx(filename=self._file_path) | |||||
| from unstructured.chunking.title import chunk_by_title | |||||
| chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=0) | |||||
| documents = [] | |||||
| for chunk in chunks: | |||||
| text = chunk.text.strip() | |||||
| documents.append(Document(page_content=text)) | |||||
| return documents |
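`UnstructuredWordExtractor` picks `partition_doc` or `partition_docx` based on the detected file type, then merges elements with `chunk_by_title` into chunks of at most 2000 characters. A usage sketch; the import path and file name are assumptions.

```python
# Sketch (import path assumed): partition a Word document locally and
# inspect the title-based chunks produced above.
from core.rag.extractor.unstructured.unstructured_doc_extractor import UnstructuredWordExtractor

docs = UnstructuredWordExtractor("/tmp/spec.docx", api_url="").extract()
for doc in docs[:3]:
    print(len(doc.page_content), doc.page_content[:60])
```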
| import logging | import logging | ||||
| from bs4 import BeautifulSoup | from bs4 import BeautifulSoup | ||||
| from langchain.document_loaders.base import BaseLoader | |||||
| from langchain.schema import Document | |||||
| from core.rag.extractor.extractor_base import BaseExtractor | |||||
| from core.rag.models.document import Document | |||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||
| class UnstructuredEmailLoader(BaseLoader): | |||||
| class UnstructuredEmailExtractor(BaseExtractor): | |||||
| """Load msg files. | """Load msg files. | ||||
| Args: | Args: | ||||
| file_path: Path to the file to load. | file_path: Path to the file to load. | ||||
| self._file_path = file_path | self._file_path = file_path | ||||
| self._api_url = api_url | self._api_url = api_url | ||||
| def load(self) -> list[Document]: | |||||
| def extract(self) -> list[Document]: | |||||
| from unstructured.partition.email import partition_email | from unstructured.partition.email import partition_email | ||||
| elements = partition_email(filename=self._file_path, api_url=self._api_url) | elements = partition_email(filename=self._file_path, api_url=self._api_url) | ||||
| import logging | import logging | ||||
| from langchain.document_loaders.base import BaseLoader | |||||
| from langchain.schema import Document | |||||
| from core.rag.extractor.extractor_base import BaseExtractor | |||||
| from core.rag.models.document import Document | |||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||
| class UnstructuredMarkdownLoader(BaseLoader): | |||||
| class UnstructuredMarkdownExtractor(BaseExtractor): | |||||
| """Load md files. | """Load md files. | ||||
| self._file_path = file_path | self._file_path = file_path | ||||
| self._api_url = api_url | self._api_url = api_url | ||||
| def load(self) -> list[Document]: | |||||
| def extract(self) -> list[Document]: | |||||
| from unstructured.partition.md import partition_md | from unstructured.partition.md import partition_md | ||||
| elements = partition_md(filename=self._file_path, api_url=self._api_url) | elements = partition_md(filename=self._file_path, api_url=self._api_url) |
| import logging | import logging | ||||
| from langchain.document_loaders.base import BaseLoader | |||||
| from langchain.schema import Document | |||||
| from core.rag.extractor.extractor_base import BaseExtractor | |||||
| from core.rag.models.document import Document | |||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||
| class UnstructuredMsgLoader(BaseLoader): | |||||
| class UnstructuredMsgExtractor(BaseExtractor): | |||||
| """Load msg files. | """Load msg files. | ||||
| self._file_path = file_path | self._file_path = file_path | ||||
| self._api_url = api_url | self._api_url = api_url | ||||
| def load(self) -> list[Document]: | |||||
| def extract(self) -> list[Document]: | |||||
| from unstructured.partition.msg import partition_msg | from unstructured.partition.msg import partition_msg | ||||
| elements = partition_msg(filename=self._file_path, api_url=self._api_url) | elements = partition_msg(filename=self._file_path, api_url=self._api_url) |
| import logging | import logging | ||||
| from langchain.document_loaders.base import BaseLoader | |||||
| from langchain.schema import Document | |||||
| from core.rag.extractor.extractor_base import BaseExtractor | |||||
| from core.rag.models.document import Document | |||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||
| class UnstructuredPPTLoader(BaseLoader): | |||||
| class UnstructuredPPTExtractor(BaseExtractor): | |||||
| """Load msg files. | """Load msg files. | ||||
| """ | """ | ||||
| def __init__( | def __init__( | ||||
| self, | |||||
| file_path: str, | |||||
| api_url: str | |||||
| self, | |||||
| file_path: str, | |||||
| api_url: str | |||||
| ): | ): | ||||
| """Initialize with file path.""" | """Initialize with file path.""" | ||||
| self._file_path = file_path | self._file_path = file_path | ||||
| self._api_url = api_url | self._api_url = api_url | ||||
| def load(self) -> list[Document]: | |||||
| def extract(self) -> list[Document]: | |||||
| from unstructured.partition.ppt import partition_ppt | from unstructured.partition.ppt import partition_ppt | ||||
| elements = partition_ppt(filename=self._file_path, api_url=self._api_url) | elements = partition_ppt(filename=self._file_path, api_url=self._api_url) |
| import logging | import logging | ||||
| from langchain.document_loaders.base import BaseLoader | |||||
| from langchain.schema import Document | |||||
| from core.rag.extractor.extractor_base import BaseExtractor | |||||
| from core.rag.models.document import Document | |||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||
| class UnstructuredPPTXLoader(BaseLoader): | |||||
| class UnstructuredPPTXExtractor(BaseExtractor): | |||||
| """Load msg files. | """Load msg files. | ||||
| """ | """ | ||||
| def __init__( | def __init__( | ||||
| self, | |||||
| file_path: str, | |||||
| api_url: str | |||||
| self, | |||||
| file_path: str, | |||||
| api_url: str | |||||
| ): | ): | ||||
| """Initialize with file path.""" | """Initialize with file path.""" | ||||
| self._file_path = file_path | self._file_path = file_path | ||||
| self._api_url = api_url | self._api_url = api_url | ||||
| def load(self) -> list[Document]: | |||||
| def extract(self) -> list[Document]: | |||||
| from unstructured.partition.pptx import partition_pptx | from unstructured.partition.pptx import partition_pptx | ||||
| elements = partition_pptx(filename=self._file_path, api_url=self._api_url) | elements = partition_pptx(filename=self._file_path, api_url=self._api_url) |
| import logging | import logging | ||||
| from langchain.document_loaders.base import BaseLoader | |||||
| from langchain.schema import Document | |||||
| from core.rag.extractor.extractor_base import BaseExtractor | |||||
| from core.rag.models.document import Document | |||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||
| class UnstructuredTextLoader(BaseLoader): | |||||
| class UnstructuredTextExtractor(BaseExtractor): | |||||
| """Load msg files. | """Load msg files. | ||||
| self._file_path = file_path | self._file_path = file_path | ||||
| self._api_url = api_url | self._api_url = api_url | ||||
| def load(self) -> list[Document]: | |||||
| def extract(self) -> list[Document]: | |||||
| from unstructured.partition.text import partition_text | from unstructured.partition.text import partition_text | ||||
| elements = partition_text(filename=self._file_path, api_url=self._api_url) | elements = partition_text(filename=self._file_path, api_url=self._api_url) |
| import logging | import logging | ||||
| from langchain.document_loaders.base import BaseLoader | |||||
| from langchain.schema import Document | |||||
| from core.rag.extractor.extractor_base import BaseExtractor | |||||
| from core.rag.models.document import Document | |||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||
| class UnstructuredXmlLoader(BaseLoader): | |||||
| class UnstructuredXmlExtractor(BaseExtractor): | |||||
| """Load msg files. | """Load msg files. | ||||
| self._file_path = file_path | self._file_path = file_path | ||||
| self._api_url = api_url | self._api_url = api_url | ||||
| def load(self) -> list[Document]: | |||||
| def extract(self) -> list[Document]: | |||||
| from unstructured.partition.xml import partition_xml | from unstructured.partition.xml import partition_xml | ||||
| elements = partition_xml(filename=self._file_path, xml_keep_tags=True, api_url=self._api_url) | elements = partition_xml(filename=self._file_path, xml_keep_tags=True, api_url=self._api_url) |
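The Unstructured-based extractors above all share one shape: call the format-specific `partition_*` function (forwarding the configured `UNSTRUCTURED_API_URL`) and wrap the element text in `Document` objects. The diff truncates each extractor at the partition call, so the join step in the sketch below is an assumption based on that pattern.

```python
# Hedged sketch of the shared pattern; the post-partition join is assumed,
# since the diff cuts each extractor off at the partition call.
from unstructured.partition.md import partition_md
from core.rag.models.document import Document

def extract_markdown(file_path: str, api_url: str) -> list[Document]:
    elements = partition_md(filename=file_path, api_url=api_url)
    text = "\n\n".join(str(element) for element in elements)
    return [Document(page_content=text)]
```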
| """Abstract interface for document loader implementations.""" | |||||
| import os | |||||
| import tempfile | |||||
| from urllib.parse import urlparse | |||||
| import requests | |||||
| from core.rag.extractor.extractor_base import BaseExtractor | |||||
| from core.rag.models.document import Document | |||||
| class WordExtractor(BaseExtractor): | |||||
| """Load pdf files. | |||||
| Args: | |||||
| file_path: Path to the file to load. | |||||
| """ | |||||
| def __init__(self, file_path: str): | |||||
| """Initialize with file path.""" | |||||
| self.file_path = file_path | |||||
| if "~" in self.file_path: | |||||
| self.file_path = os.path.expanduser(self.file_path) | |||||
| # If the file is a web path, download it to a temporary file, and use that | |||||
| if not os.path.isfile(self.file_path) and self._is_valid_url(self.file_path): | |||||
| r = requests.get(self.file_path) | |||||
| if r.status_code != 200: | |||||
| raise ValueError( | |||||
| "Check the url of your file; returned status code %s" | |||||
| % r.status_code | |||||
| ) | |||||
| self.web_path = self.file_path | |||||
| self.temp_file = tempfile.NamedTemporaryFile() | |||||
| self.temp_file.write(r.content) | |||||
| self.file_path = self.temp_file.name | |||||
| elif not os.path.isfile(self.file_path): | |||||
| raise ValueError("File path %s is not a valid file or url" % self.file_path) | |||||
| def __del__(self) -> None: | |||||
| if hasattr(self, "temp_file"): | |||||
| self.temp_file.close() | |||||
| def extract(self) -> list[Document]: | |||||
| """Load given path as single page.""" | |||||
| import docx2txt | |||||
| return [ | |||||
| Document( | |||||
| page_content=docx2txt.process(self.file_path), | |||||
| metadata={"source": self.file_path}, | |||||
| ) | |||||
| ] | |||||
| @staticmethod | |||||
| def _is_valid_url(url: str) -> bool: | |||||
| """Check if the url is valid.""" | |||||
| parsed = urlparse(url) | |||||
| return bool(parsed.netloc) and bool(parsed.scheme) |
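`WordExtractor` accepts either a local path or an http(s) URL; remote files are first downloaded into a `NamedTemporaryFile` that lives as long as the extractor. A short sketch, with the module path and file locations assumed.

```python
# Sketch: both forms resolve to a single Document produced by docx2txt.
from core.rag.extractor.word_extractor import WordExtractor   # assumed path

local_docs = WordExtractor("~/Documents/handbook.docx").extract()
remote_docs = WordExtractor("https://example.com/files/handbook.docx").extract()
print(local_docs[0].metadata["source"])
```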
| from enum import Enum | |||||
| class IndexType(Enum): | |||||
| PARAGRAPH_INDEX = "text_model" | |||||
| QA_INDEX = "qa_model" | |||||
| PARENT_CHILD_INDEX = "parent_child_index" | |||||
| SUMMARY_INDEX = "summary_index" |
| """Abstract interface for document loader implementations.""" | |||||
| from abc import ABC, abstractmethod | |||||
| from typing import Optional | |||||
| from langchain.text_splitter import TextSplitter | |||||
| from core.model_manager import ModelInstance | |||||
| from core.rag.extractor.entity.extract_setting import ExtractSetting | |||||
| from core.rag.models.document import Document | |||||
| from core.splitter.fixed_text_splitter import EnhanceRecursiveCharacterTextSplitter, FixedRecursiveCharacterTextSplitter | |||||
| from models.dataset import Dataset, DatasetProcessRule | |||||
| class BaseIndexProcessor(ABC): | |||||
| """Interface for extract files. | |||||
| """ | |||||
| @abstractmethod | |||||
| def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: | |||||
| raise NotImplementedError | |||||
| @abstractmethod | |||||
| def transform(self, documents: list[Document], **kwargs) -> list[Document]: | |||||
| raise NotImplementedError | |||||
| @abstractmethod | |||||
| def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True): | |||||
| raise NotImplementedError | |||||
| def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True): | |||||
| raise NotImplementedError | |||||
| @abstractmethod | |||||
| def retrieve(self, retrival_method: str, query: str, dataset: Dataset, top_k: int, | |||||
| score_threshold: float, reranking_model: dict) -> list[Document]: | |||||
| raise NotImplementedError | |||||
| def _get_splitter(self, processing_rule: dict, | |||||
| embedding_model_instance: Optional[ModelInstance]) -> TextSplitter: | |||||
| """ | |||||
| Get the text splitter according to the processing rule. | |||||
| """ | |||||
| if processing_rule['mode'] == "custom": | |||||
| # The user-defined segmentation rule | |||||
| rules = processing_rule['rules'] | |||||
| segmentation = rules["segmentation"] | |||||
| if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > 1000: | |||||
| raise ValueError("Custom segment length should be between 50 and 1000.") | |||||
| separator = segmentation["separator"] | |||||
| if separator: | |||||
| separator = separator.replace('\\n', '\n') | |||||
| character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder( | |||||
| chunk_size=segmentation["max_tokens"], | |||||
| chunk_overlap=0, | |||||
| fixed_separator=separator, | |||||
| separators=["\n\n", "。", ".", " ", ""], | |||||
| embedding_model_instance=embedding_model_instance | |||||
| ) | |||||
| else: | |||||
| # Automatic segmentation | |||||
| character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder( | |||||
| chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'], | |||||
| chunk_overlap=0, | |||||
| separators=["\n\n", "。", ".", " ", ""], | |||||
| embedding_model_instance=embedding_model_instance | |||||
| ) | |||||
| return character_splitter |
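`_get_splitter` expects the processing rule either in `automatic` mode or in `custom` mode with a `segmentation` block. A sketch of both shapes, matching the keys checked above; the concrete values are only examples.

```python
# Example rule dictionaries accepted by _get_splitter (values illustrative).
custom_process_rule = {
    "mode": "custom",
    "rules": {
        "segmentation": {
            "max_tokens": 500,    # must lie within [50, 1000]
            "separator": "\\n"    # '\\n' is converted to a real newline above
        }
    }
}

automatic_process_rule = {
    "mode": "automatic"           # falls back to DatasetProcessRule.AUTOMATIC_RULES
}
```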
| """Abstract interface for document loader implementations.""" | |||||
| from core.rag.index_processor.constant.index_type import IndexType | |||||
| from core.rag.index_processor.index_processor_base import BaseIndexProcessor | |||||
| from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor | |||||
| from core.rag.index_processor.processor.qa_index_processor import QAIndexProcessor | |||||
| class IndexProcessorFactory: | |||||
| """IndexProcessorInit. | |||||
| """ | |||||
| def __init__(self, index_type: str): | |||||
| self._index_type = index_type | |||||
| def init_index_processor(self) -> BaseIndexProcessor: | |||||
| """Init index processor.""" | |||||
| if not self._index_type: | |||||
| raise ValueError("Index type must be specified.") | |||||
| if self._index_type == IndexType.PARAGRAPH_INDEX.value: | |||||
| return ParagraphIndexProcessor() | |||||
| elif self._index_type == IndexType.QA_INDEX.value: | |||||
| return QAIndexProcessor() | |||||
| else: | |||||
| raise ValueError(f"Index type {self._index_type} is not supported.") |
| """Paragraph index processor.""" | |||||
| import uuid | |||||
| from typing import Optional | |||||
| from core.rag.cleaner.clean_processor import CleanProcessor | |||||
| from core.rag.datasource.keyword.keyword_factory import Keyword | |||||
| from core.rag.datasource.retrieval_service import RetrievalService | |||||
| from core.rag.datasource.vdb.vector_factory import Vector | |||||
| from core.rag.extractor.entity.extract_setting import ExtractSetting | |||||
| from core.rag.extractor.extract_processor import ExtractProcessor | |||||
| from core.rag.index_processor.index_processor_base import BaseIndexProcessor | |||||
| from core.rag.models.document import Document | |||||
| from libs import helper | |||||
| from models.dataset import Dataset | |||||
| class ParagraphIndexProcessor(BaseIndexProcessor): | |||||
| def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: | |||||
| text_docs = ExtractProcessor.extract(extract_setting=extract_setting, | |||||
| is_automatic=kwargs.get('process_rule_mode') == "automatic") | |||||
| return text_docs | |||||
| def transform(self, documents: list[Document], **kwargs) -> list[Document]: | |||||
| # Split the text documents into nodes. | |||||
| splitter = self._get_splitter(processing_rule=kwargs.get('process_rule'), | |||||
| embedding_model_instance=kwargs.get('embedding_model_instance')) | |||||
| all_documents = [] | |||||
| for document in documents: | |||||
| # document clean | |||||
| document_text = CleanProcessor.clean(document.page_content, kwargs.get('process_rule')) | |||||
| document.page_content = document_text | |||||
| # parse document to nodes | |||||
| document_nodes = splitter.split_documents([document]) | |||||
| split_documents = [] | |||||
| for document_node in document_nodes: | |||||
| if document_node.page_content.strip(): | |||||
| doc_id = str(uuid.uuid4()) | |||||
| hash = helper.generate_text_hash(document_node.page_content) | |||||
| document_node.metadata['doc_id'] = doc_id | |||||
| document_node.metadata['doc_hash'] = hash | |||||
| # remove a leading separator character left over from splitting | |||||
| page_content = document_node.page_content | |||||
| if page_content.startswith(".") or page_content.startswith("。"): | |||||
| page_content = page_content[1:] | |||||
| document_node.page_content = page_content | |||||
| split_documents.append(document_node) | |||||
| all_documents.extend(split_documents) | |||||
| return all_documents | |||||
| def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True): | |||||
| if dataset.indexing_technique == 'high_quality': | |||||
| vector = Vector(dataset) | |||||
| vector.create(documents) | |||||
| if with_keywords: | |||||
| keyword = Keyword(dataset) | |||||
| keyword.create(documents) | |||||
| def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True): | |||||
| if dataset.indexing_technique == 'high_quality': | |||||
| vector = Vector(dataset) | |||||
| if node_ids: | |||||
| vector.delete_by_ids(node_ids) | |||||
| else: | |||||
| vector.delete() | |||||
| if with_keywords: | |||||
| keyword = Keyword(dataset) | |||||
| if node_ids: | |||||
| keyword.delete_by_ids(node_ids) | |||||
| else: | |||||
| keyword.delete() | |||||
| def retrieve(self, retrival_method: str, query: str, dataset: Dataset, top_k: int, | |||||
| score_threshold: float, reranking_model: dict) -> list[Document]: | |||||
| # Set search parameters. | |||||
| results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id, query=query, | |||||
| top_k=top_k, score_threshold=score_threshold, | |||||
| reranking_model=reranking_model) | |||||
| # Organize results. | |||||
| docs = [] | |||||
| for result in results: | |||||
| metadata = result.metadata | |||||
| metadata['score'] = result.score | |||||
| if result.score > score_threshold: | |||||
| doc = Document(page_content=result.page_content, metadata=metadata) | |||||
| docs.append(doc) | |||||
| return docs |
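Putting the pieces together, the paragraph processor is driven as extract → transform → load. A hedged end-to-end sketch: `dataset`, `extract_setting` and `embedding_model_instance` are assumed to already exist, and the kwargs simply mirror the keys read in `extract()` and `transform()` above.

```python
# Hedged pipeline sketch, not a definitive implementation.
index_processor = IndexProcessorFactory(IndexType.PARAGRAPH_INDEX.value).init_index_processor()

text_docs = index_processor.extract(extract_setting, process_rule_mode="automatic")
nodes = index_processor.transform(
    text_docs,
    process_rule={"mode": "automatic"},
    embedding_model_instance=embedding_model_instance
)
index_processor.load(dataset, nodes, with_keywords=True)   # vector + keyword indexes
```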
| """Paragraph index processor.""" | |||||
| import logging | |||||
| import re | |||||
| import threading | |||||
| import uuid | |||||
| from typing import Optional | |||||
| import pandas as pd | |||||
| from flask import Flask, current_app | |||||
| from flask_login import current_user | |||||
| from werkzeug.datastructures import FileStorage | |||||
| from core.generator.llm_generator import LLMGenerator | |||||
| from core.rag.cleaner.clean_processor import CleanProcessor | |||||
| from core.rag.datasource.retrieval_service import RetrievalService | |||||
| from core.rag.datasource.vdb.vector_factory import Vector | |||||
| from core.rag.extractor.entity.extract_setting import ExtractSetting | |||||
| from core.rag.extractor.extract_processor import ExtractProcessor | |||||
| from core.rag.index_processor.index_processor_base import BaseIndexProcessor | |||||
| from core.rag.models.document import Document | |||||
| from libs import helper | |||||
| from models.dataset import Dataset | |||||
| class QAIndexProcessor(BaseIndexProcessor): | |||||
| def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: | |||||
| text_docs = ExtractProcessor.extract(extract_setting=extract_setting, | |||||
| is_automatic=kwargs.get('process_rule_mode') == "automatic") | |||||
| return text_docs | |||||
| def transform(self, documents: list[Document], **kwargs) -> list[Document]: | |||||
| splitter = self._get_splitter(processing_rule=kwargs.get('process_rule'), | |||||
| embedding_model_instance=None) | |||||
| # Split the text documents into nodes. | |||||
| all_documents = [] | |||||
| all_qa_documents = [] | |||||
| for document in documents: | |||||
| # document clean | |||||
| document_text = CleanProcessor.clean(document.page_content, kwargs.get('process_rule')) | |||||
| document.page_content = document_text | |||||
| # parse document to nodes | |||||
| document_nodes = splitter.split_documents([document]) | |||||
| split_documents = [] | |||||
| for document_node in document_nodes: | |||||
| if document_node.page_content.strip(): | |||||
| doc_id = str(uuid.uuid4()) | |||||
| hash = helper.generate_text_hash(document_node.page_content) | |||||
| document_node.metadata['doc_id'] = doc_id | |||||
| document_node.metadata['doc_hash'] = hash | |||||
| # remove a leading separator character left over from splitting | |||||
| page_content = document_node.page_content | |||||
| if page_content.startswith(".") or page_content.startswith("。"): | |||||
| page_content = page_content[1:] | |||||
| document_node.page_content = page_content | |||||
| split_documents.append(document_node) | |||||
| all_documents.extend(split_documents) | |||||
| for i in range(0, len(all_documents), 10): | |||||
| threads = [] | |||||
| sub_documents = all_documents[i:i + 10] | |||||
| for doc in sub_documents: | |||||
| document_format_thread = threading.Thread(target=self._format_qa_document, kwargs={ | |||||
| 'flask_app': current_app._get_current_object(), | |||||
| 'tenant_id': current_user.current_tenant.id, | |||||
| 'document_node': doc, | |||||
| 'all_qa_documents': all_qa_documents, | |||||
| 'document_language': kwargs.get('document_language', 'English')}) | |||||
| threads.append(document_format_thread) | |||||
| document_format_thread.start() | |||||
| for thread in threads: | |||||
| thread.join() | |||||
| return all_qa_documents | |||||
| def format_by_template(self, file: FileStorage, **kwargs) -> list[Document]: | |||||
| # check file type | |||||
| if not file.filename.endswith('.csv'): | |||||
| raise ValueError("Invalid file type. Only CSV files are allowed") | |||||
| try: | |||||
| # The first CSV row is treated as the header; data rows follow it | |||||
| df = pd.read_csv(file) | |||||
| text_docs = [] | |||||
| for index, row in df.iterrows(): | |||||
| data = Document(page_content=row[0], metadata={'answer': row[1]}) | |||||
| text_docs.append(data) | |||||
| if len(text_docs) == 0: | |||||
| raise ValueError("The CSV file is empty.") | |||||
| except Exception as e: | |||||
| raise ValueError(str(e)) | |||||
| return text_docs | |||||
| def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True): | |||||
| if dataset.indexing_technique == 'high_quality': | |||||
| vector = Vector(dataset) | |||||
| vector.create(documents) | |||||
| def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True): | |||||
| vector = Vector(dataset) | |||||
| if node_ids: | |||||
| vector.delete_by_ids(node_ids) | |||||
| else: | |||||
| vector.delete() | |||||
| def retrieve(self, retrival_method: str, query: str, dataset: Dataset, top_k: int, | |||||
| score_threshold: float, reranking_model: dict): | |||||
| # Set search parameters. | |||||
| results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id, query=query, | |||||
| top_k=top_k, score_threshold=score_threshold, | |||||
| reranking_model=reranking_model) | |||||
| # Organize results. | |||||
| docs = [] | |||||
| for result in results: | |||||
| metadata = result.metadata | |||||
| metadata['score'] = result.score | |||||
| if result.score > score_threshold: | |||||
| doc = Document(page_content=result.page_content, metadata=metadata) | |||||
| docs.append(doc) | |||||
| return docs | |||||
| def _format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language): | |||||
| format_documents = [] | |||||
| if document_node.page_content is None or not document_node.page_content.strip(): | |||||
| return | |||||
| with flask_app.app_context(): | |||||
| try: | |||||
| # qa model document | |||||
| response = LLMGenerator.generate_qa_document(tenant_id, document_node.page_content, document_language) | |||||
| document_qa_list = self._format_split_text(response) | |||||
| qa_documents = [] | |||||
| for result in document_qa_list: | |||||
| qa_document = Document(page_content=result['question'], metadata=document_node.metadata.copy()) | |||||
| doc_id = str(uuid.uuid4()) | |||||
| hash = helper.generate_text_hash(result['question']) | |||||
| qa_document.metadata['answer'] = result['answer'] | |||||
| qa_document.metadata['doc_id'] = doc_id | |||||
| qa_document.metadata['doc_hash'] = hash | |||||
| qa_documents.append(qa_document) | |||||
| format_documents.extend(qa_documents) | |||||
| except Exception as e: | |||||
| logging.exception(e) | |||||
| all_qa_documents.extend(format_documents) | |||||
| def _format_split_text(self, text): | |||||
| regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)" | |||||
| matches = re.findall(regex, text, re.UNICODE) | |||||
| return [ | |||||
| { | |||||
| "question": q, | |||||
| "answer": re.sub(r"\n\s*", "\n", a.strip()) | |||||
| } | |||||
| for q, a in matches if q and a | |||||
| ] |
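`_format_split_text` assumes the LLM answers in a numbered `Qn:`/`An:` layout; the regex above captures each pair and normalises whitespace inside the answers. An illustrative input and the structure it should produce (the sample text is invented):

```python
# Illustrative input for _format_split_text.
sample = """Q1: What does an extractor return?
A1: A list of Document objects.
Q2: Which processor builds question/answer pairs?
A2: QAIndexProcessor."""

pairs = QAIndexProcessor()._format_split_text(sample)
# Expected (sketch):
# [{'question': 'What does an extractor return?',
#   'answer': 'A list of Document objects.'},
#  {'question': 'Which processor builds question/answer pairs?',
#   'answer': 'QAIndexProcessor.'}]
```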
| from typing import Optional | |||||
| from pydantic import BaseModel, Field | |||||
| class Document(BaseModel): | |||||
| """Class for storing a piece of text and associated metadata.""" | |||||
| page_content: str | |||||
| """Arbitrary metadata about the page content (e.g., source, relationships to other | |||||
| documents, etc.). | |||||
| """ | |||||
| metadata: Optional[dict] = Field(default_factory=dict) | |||||
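`Document` is the plain pydantic container every extractor and processor above passes around; for example:

```python
# Sketch: page text plus free-form metadata (doc_id, doc_hash, score, ...).
from core.rag.models.document import Document

doc = Document(page_content="hello world", metadata={"source": "/tmp/hello.txt"})
```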
| from regex import regex | from regex import regex | ||||
| from core.chain.llm_chain import LLMChain | from core.chain.llm_chain import LLMChain | ||||
| from core.data_loader import file_extractor | |||||
| from core.data_loader.file_extractor import FileExtractor | |||||
| from core.entities.application_entities import ModelConfigEntity | from core.entities.application_entities import ModelConfigEntity | ||||
| from core.rag.extractor import extract_processor | |||||
| from core.rag.extractor.extract_processor import ExtractProcessor | |||||
| FULL_TEMPLATE = """ | FULL_TEMPLATE = """ | ||||
| TITLE: {title} | TITLE: {title} | ||||
| headers = { | headers = { | ||||
| "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" | "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" | ||||
| } | } | ||||
| supported_content_types = file_extractor.SUPPORT_URL_CONTENT_TYPES + ["text/html"] | |||||
| supported_content_types = extract_processor.SUPPORT_URL_CONTENT_TYPES + ["text/html"] | |||||
| head_response = requests.head(url, headers=headers, allow_redirects=True, timeout=(5, 10)) | head_response = requests.head(url, headers=headers, allow_redirects=True, timeout=(5, 10)) | ||||
| if main_content_type not in supported_content_types: | if main_content_type not in supported_content_types: | ||||
| return "Unsupported content-type [{}] of URL.".format(main_content_type) | return "Unsupported content-type [{}] of URL.".format(main_content_type) | ||||
| if main_content_type in file_extractor.SUPPORT_URL_CONTENT_TYPES: | |||||
| return FileExtractor.load_from_url(url, return_text=True) | |||||
| if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES: | |||||
| return ExtractProcessor.load_from_url(url, return_text=True) | |||||
| response = requests.get(url, headers=headers, allow_redirects=True, timeout=(5, 30)) | response = requests.get(url, headers=headers, allow_redirects=True, timeout=(5, 30)) | ||||
| a = extract_using_readabilipy(response.text) | a = extract_using_readabilipy(response.text) |
| from pydantic import BaseModel, Field | from pydantic import BaseModel, Field | ||||
| from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler | from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler | ||||
| from core.embedding.cached_embedding import CacheEmbedding | |||||
| from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError | |||||
| from core.index.keyword_table_index.keyword_table_index import KeywordTableConfig, KeywordTableIndex | |||||
| from core.model_manager import ModelManager | from core.model_manager import ModelManager | ||||
| from core.model_runtime.entities.model_entities import ModelType | from core.model_runtime.entities.model_entities import ModelType | ||||
| from core.rag.datasource.retrieval_service import RetrievalService | |||||
| from core.rerank.rerank import RerankRunner | from core.rerank.rerank import RerankRunner | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from models.dataset import Dataset, Document, DocumentSegment | from models.dataset import Dataset, Document, DocumentSegment | ||||
| from services.retrieval_service import RetrievalService | |||||
| default_retrieval_model = { | default_retrieval_model = { | ||||
| 'search_method': 'semantic_search', | 'search_method': 'semantic_search', | ||||
| if dataset.indexing_technique == "economy": | if dataset.indexing_technique == "economy": | ||||
| # use keyword table query | # use keyword table query | ||||
| kw_table_index = KeywordTableIndex( | |||||
| dataset=dataset, | |||||
| config=KeywordTableConfig( | |||||
| max_keywords_per_chunk=5 | |||||
| ) | |||||
| ) | |||||
| documents = kw_table_index.search(query, search_kwargs={'k': self.top_k}) | |||||
| documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'], | |||||
| dataset_id=dataset.id, | |||||
| query=query, | |||||
| top_k=self.top_k | |||||
| ) | |||||
| if documents: | if documents: | ||||
| all_documents.extend(documents) | all_documents.extend(documents) | ||||
| else: | else: | ||||
| try: | |||||
| model_manager = ModelManager() | |||||
| embedding_model = model_manager.get_model_instance( | |||||
| tenant_id=dataset.tenant_id, | |||||
| provider=dataset.embedding_model_provider, | |||||
| model_type=ModelType.TEXT_EMBEDDING, | |||||
| model=dataset.embedding_model | |||||
| ) | |||||
| except LLMBadRequestError: | |||||
| return [] | |||||
| except ProviderTokenNotInitError: | |||||
| return [] | |||||
| embeddings = CacheEmbedding(embedding_model) | |||||
| documents = [] | |||||
| threads = [] | |||||
| if self.top_k > 0: | if self.top_k > 0: | ||||
| # retrieval_model source with semantic | |||||
| if retrieval_model['search_method'] == 'semantic_search' or retrieval_model[ | |||||
| 'search_method'] == 'hybrid_search': | |||||
| embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={ | |||||
| 'flask_app': current_app._get_current_object(), | |||||
| 'dataset_id': str(dataset.id), | |||||
| 'query': query, | |||||
| 'top_k': self.top_k, | |||||
| 'score_threshold': self.score_threshold, | |||||
| 'reranking_model': None, | |||||
| 'all_documents': documents, | |||||
| 'search_method': 'hybrid_search', | |||||
| 'embeddings': embeddings | |||||
| }) | |||||
| threads.append(embedding_thread) | |||||
| embedding_thread.start() | |||||
| # retrieval_model source with full text | |||||
| if retrieval_model['search_method'] == 'full_text_search' or retrieval_model[ | |||||
| 'search_method'] == 'hybrid_search': | |||||
| full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, | |||||
| kwargs={ | |||||
| 'flask_app': current_app._get_current_object(), | |||||
| 'dataset_id': str(dataset.id), | |||||
| 'query': query, | |||||
| 'search_method': 'hybrid_search', | |||||
| 'embeddings': embeddings, | |||||
| 'score_threshold': retrieval_model[ | |||||
| 'score_threshold'] if retrieval_model[ | |||||
| 'score_threshold_enabled'] else None, | |||||
| 'top_k': self.top_k, | |||||
| 'reranking_model': retrieval_model[ | |||||
| 'reranking_model'] if retrieval_model[ | |||||
| 'reranking_enable'] else None, | |||||
| 'all_documents': documents | |||||
| }) | |||||
| threads.append(full_text_index_thread) | |||||
| full_text_index_thread.start() | |||||
| for thread in threads: | |||||
| thread.join() | |||||
| # retrieval source | |||||
| documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'], | |||||
| dataset_id=dataset.id, | |||||
| query=query, | |||||
| top_k=self.top_k, | |||||
| score_threshold=retrieval_model['score_threshold'] | |||||
| if retrieval_model['score_threshold_enabled'] else None, | |||||
| reranking_model=retrieval_model['reranking_model'] | |||||
| if retrieval_model['reranking_enable'] else None | |||||
| ) | |||||
| all_documents.extend(documents) | all_documents.extend(documents) |
| import threading | |||||
| from typing import Optional | from typing import Optional | ||||
| from flask import current_app | |||||
| from langchain.tools import BaseTool | from langchain.tools import BaseTool | ||||
| from pydantic import BaseModel, Field | from pydantic import BaseModel, Field | ||||
| from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler | from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler | ||||
| from core.embedding.cached_embedding import CacheEmbedding | |||||
| from core.index.keyword_table_index.keyword_table_index import KeywordTableConfig, KeywordTableIndex | |||||
| from core.model_manager import ModelManager | |||||
| from core.model_runtime.entities.model_entities import ModelType | |||||
| from core.model_runtime.errors.invoke import InvokeAuthorizationError | |||||
| from core.rerank.rerank import RerankRunner | |||||
| from core.rag.datasource.retrieval_service import RetrievalService | |||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from models.dataset import Dataset, Document, DocumentSegment | from models.dataset import Dataset, Document, DocumentSegment | ||||
| from services.retrieval_service import RetrievalService | |||||
| default_retrieval_model = { | default_retrieval_model = { | ||||
| 'search_method': 'semantic_search', | 'search_method': 'semantic_search', | ||||
| retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model | retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model | ||||
| if dataset.indexing_technique == "economy": | if dataset.indexing_technique == "economy": | ||||
| # use keyword table query | # use keyword table query | ||||
| kw_table_index = KeywordTableIndex( | |||||
| dataset=dataset, | |||||
| config=KeywordTableConfig( | |||||
| max_keywords_per_chunk=5 | |||||
| ) | |||||
| ) | |||||
| documents = kw_table_index.search(query, search_kwargs={'k': self.top_k}) | |||||
| documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'], | |||||
| dataset_id=dataset.id, | |||||
| query=query, | |||||
| top_k=self.top_k | |||||
| ) | |||||
| return str("\n".join([document.page_content for document in documents])) | return str("\n".join([document.page_content for document in documents])) | ||||
| else: | else: | ||||
| # get embedding model instance | |||||
| try: | |||||
| model_manager = ModelManager() | |||||
| embedding_model = model_manager.get_model_instance( | |||||
| tenant_id=dataset.tenant_id, | |||||
| provider=dataset.embedding_model_provider, | |||||
| model_type=ModelType.TEXT_EMBEDDING, | |||||
| model=dataset.embedding_model | |||||
| ) | |||||
| except InvokeAuthorizationError: | |||||
| return '' | |||||
| embeddings = CacheEmbedding(embedding_model) | |||||
| documents = [] | |||||
| threads = [] | |||||
| if self.top_k > 0: | if self.top_k > 0: | ||||
| # retrieval source with semantic | |||||
| if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search': | |||||
| embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={ | |||||
| 'flask_app': current_app._get_current_object(), | |||||
| 'dataset_id': str(dataset.id), | |||||
| 'query': query, | |||||
| 'top_k': self.top_k, | |||||
| 'score_threshold': retrieval_model['score_threshold'] if retrieval_model[ | |||||
| 'score_threshold_enabled'] else None, | |||||
| 'reranking_model': retrieval_model['reranking_model'] if retrieval_model[ | |||||
| 'reranking_enable'] else None, | |||||
| 'all_documents': documents, | |||||
| 'search_method': retrieval_model['search_method'], | |||||
| 'embeddings': embeddings | |||||
| }) | |||||
| threads.append(embedding_thread) | |||||
| embedding_thread.start() | |||||
| # retrieval_model source with full text | |||||
| if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search': | |||||
| full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={ | |||||
| 'flask_app': current_app._get_current_object(), | |||||
| 'dataset_id': str(dataset.id), | |||||
| 'query': query, | |||||
| 'search_method': retrieval_model['search_method'], | |||||
| 'embeddings': embeddings, | |||||
| 'score_threshold': retrieval_model['score_threshold'] if retrieval_model[ | |||||
| 'score_threshold_enabled'] else None, | |||||
| 'top_k': self.top_k, | |||||
| 'reranking_model': retrieval_model['reranking_model'] if retrieval_model[ | |||||
| 'reranking_enable'] else None, | |||||
| 'all_documents': documents | |||||
| }) | |||||
| threads.append(full_text_index_thread) | |||||
| full_text_index_thread.start() | |||||
| for thread in threads: | |||||
| thread.join() | |||||
| # hybrid search: rerank after all documents have been searched | |||||
| if retrieval_model['search_method'] == 'hybrid_search': | |||||
| # get rerank model instance | |||||
| try: | |||||
| model_manager = ModelManager() | |||||
| rerank_model_instance = model_manager.get_model_instance( | |||||
| tenant_id=dataset.tenant_id, | |||||
| provider=retrieval_model['reranking_model']['reranking_provider_name'], | |||||
| model_type=ModelType.RERANK, | |||||
| model=retrieval_model['reranking_model']['reranking_model_name'] | |||||
| ) | |||||
| except InvokeAuthorizationError: | |||||
| return '' | |||||
| rerank_runner = RerankRunner(rerank_model_instance) | |||||
| documents = rerank_runner.run( | |||||
| query=query, | |||||
| documents=documents, | |||||
| score_threshold=retrieval_model['score_threshold'] if retrieval_model[ | |||||
| 'score_threshold_enabled'] else None, | |||||
| top_n=self.top_k | |||||
| ) | |||||
| # retrieval source | |||||
| documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'], | |||||
| dataset_id=dataset.id, | |||||
| query=query, | |||||
| top_k=self.top_k, | |||||
| score_threshold=retrieval_model['score_threshold'] | |||||
| if retrieval_model['score_threshold_enabled'] else None, | |||||
| reranking_model=retrieval_model['reranking_model'] | |||||
| if retrieval_model['reranking_enable'] else None | |||||
| ) | |||||
| else: | else: | ||||
| documents = [] | documents = [] | ||||
| return str("\n".join(document_context_list)) | return str("\n".join(document_context_list)) | ||||
| async def _arun(self, tool_input: str) -> str: | async def _arun(self, tool_input: str) -> str: | ||||
| raise NotImplementedError() | |||||
| raise NotImplementedError() |
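The hunk above removes the ad-hoc keyword-table lookup and the per-method embedding/full-text/rerank threads, and routes every dataset mode through RetrievalService.retrieve. A minimal sketch of that unified path, assuming a retrieval_model dict with the keys used in the diff (search_method, score_threshold_enabled, score_threshold, reranking_enable, reranking_model); the helper name is illustrative, not the actual tool implementation.

from core.rag.datasource.retrieval_service import RetrievalService

def retrieve_dataset_context(dataset, query: str, top_k: int) -> str:
    # Sketch only: mirrors the branching shown in the diff above.
    retrieval_model = dataset.retrieval_model or {
        'search_method': 'semantic_search',
        'reranking_enable': False,
        'reranking_model': None,
        'score_threshold_enabled': False,
        'score_threshold': None,
    }
    if dataset.indexing_technique == "economy":
        # economy datasets use the keyword index; no threshold or reranking needed
        documents = RetrievalService.retrieve(
            retrival_method=retrieval_model['search_method'],  # parameter name as it appears in the diff
            dataset_id=dataset.id,
            query=query,
            top_k=top_k,
        )
    else:
        documents = RetrievalService.retrieve(
            retrival_method=retrieval_model['search_method'],
            dataset_id=dataset.id,
            query=query,
            top_k=top_k,
            score_threshold=retrieval_model['score_threshold']
            if retrieval_model['score_threshold_enabled'] else None,
            reranking_model=retrieval_model['reranking_model']
            if retrieval_model['reranking_enable'] else None,
        )
    return "\n".join(document.page_content for document in documents)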
| from regex import regex | from regex import regex | ||||
| from core.chain.llm_chain import LLMChain | from core.chain.llm_chain import LLMChain | ||||
| from core.data_loader import file_extractor | |||||
| from core.data_loader.file_extractor import FileExtractor | |||||
| from core.entities.application_entities import ModelConfigEntity | from core.entities.application_entities import ModelConfigEntity | ||||
| from core.rag.extractor import extract_processor | |||||
| from core.rag.extractor.extract_processor import ExtractProcessor | |||||
| FULL_TEMPLATE = """ | FULL_TEMPLATE = """ | ||||
| TITLE: {title} | TITLE: {title} | ||||
| if user_agent: | if user_agent: | ||||
| headers["User-Agent"] = user_agent | headers["User-Agent"] = user_agent | ||||
| supported_content_types = file_extractor.SUPPORT_URL_CONTENT_TYPES + ["text/html"] | |||||
| supported_content_types = extract_processor.SUPPORT_URL_CONTENT_TYPES + ["text/html"] | |||||
| head_response = requests.head(url, headers=headers, allow_redirects=True, timeout=(5, 10)) | head_response = requests.head(url, headers=headers, allow_redirects=True, timeout=(5, 10)) | ||||
| if main_content_type not in supported_content_types: | if main_content_type not in supported_content_types: | ||||
| return "Unsupported content-type [{}] of URL.".format(main_content_type) | return "Unsupported content-type [{}] of URL.".format(main_content_type) | ||||
| if main_content_type in file_extractor.SUPPORT_URL_CONTENT_TYPES: | |||||
| return FileExtractor.load_from_url(url, return_text=True) | |||||
| if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES: | |||||
| return ExtractProcessor.load_from_url(url, return_text=True) | |||||
| response = requests.get(url, headers=headers, allow_redirects=True, timeout=(5, 30)) | response = requests.get(url, headers=headers, allow_redirects=True, timeout=(5, 30)) | ||||
| a = extract_using_readabilipy(response.text) | a = extract_using_readabilipy(response.text) |
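The web reader tool now goes through ExtractProcessor instead of FileExtractor. A small sketch of the dispatch shown above, assuming the Content-Type parsing from the elided parts of the file; only the import paths and the two branch calls come from the diff itself.

import requests
from core.rag.extractor import extract_processor
from core.rag.extractor.extract_processor import ExtractProcessor

def load_url_content(url: str, headers: dict):
    # Sketch only: content-type dispatch as shown in the diff above.
    supported_content_types = extract_processor.SUPPORT_URL_CONTENT_TYPES + ["text/html"]
    head_response = requests.head(url, headers=headers, allow_redirects=True, timeout=(5, 10))
    main_content_type = head_response.headers.get('Content-Type', '').split(';')[0].strip()
    if main_content_type not in supported_content_types:
        return "Unsupported content-type [{}] of URL.".format(main_content_type)
    if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES:
        # PDFs, Office files, etc. go through the unified extractor
        return ExtractProcessor.load_from_url(url, return_text=True)
    # plain HTML: fetch the page; the real tool then runs extract_using_readabilipy on it
    response = requests.get(url, headers=headers, allow_redirects=True, timeout=(5, 30))
    return response.text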
| from core.vector_store.vector.milvus import Milvus | |||||
| class MilvusVectorStore(Milvus): | |||||
| def del_texts(self, where_filter: dict): | |||||
| if not where_filter: | |||||
| raise ValueError('where_filter must not be empty') | |||||
| self.col.delete(where_filter.get('filter')) | |||||
| def del_text(self, uuid: str) -> None: | |||||
| expr = f"id == {uuid}" | |||||
| self.col.delete(expr) | |||||
| def text_exists(self, uuid: str) -> bool: | |||||
| result = self.col.query( | |||||
| expr=f'metadata["doc_id"] == "{uuid}"', | |||||
| output_fields=["id"] | |||||
| ) | |||||
| return len(result) > 0 | |||||
| def get_ids_by_document_id(self, document_id: str): | |||||
| result = self.col.query( | |||||
| expr=f'metadata["document_id"] == "{document_id}"', | |||||
| output_fields=["id"] | |||||
| ) | |||||
| if result: | |||||
| return [item["id"] for item in result] | |||||
| else: | |||||
| return None | |||||
| def get_ids_by_metadata_field(self, key: str, value: str): | |||||
| result = self.col.query( | |||||
| expr=f'metadata["{key}"] == "{value}"', | |||||
| output_fields=["id"] | |||||
| ) | |||||
| if result: | |||||
| return [item["id"] for item in result] | |||||
| else: | |||||
| return None | |||||
| def get_ids_by_doc_ids(self, doc_ids: list): | |||||
| result = self.col.query( | |||||
| expr=f'metadata["doc_id"] in {doc_ids}', | |||||
| output_fields=["id"] | |||||
| ) | |||||
| if result: | |||||
| return [item["id"] for item in result] | |||||
| else: | |||||
| return None | |||||
| def delete(self): | |||||
| from pymilvus import utility | |||||
| utility.drop_collection(self.collection_name, None, self.alias) | |||||
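MilvusVectorStore above implements deletion and lookups purely through Milvus boolean expressions against the JSON metadata field. A hedged usage sketch, assuming store is an already constructed MilvusVectorStore; the id values are placeholders.

# Illustrative only: `store` is assumed to be a constructed MilvusVectorStore.
document_id = "example-document-id"

ids = store.get_ids_by_document_id(document_id)   # primary keys of the document's chunks, or None
if ids:
    # del_texts expects {'filter': <boolean expr>}; 'id in [...]' removes all matched rows
    store.del_texts({'filter': f'id in {ids}'})

# point lookups use the same expression syntax on metadata["doc_id"]
chunk_present = store.text_exists("example-chunk-uuid")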
| from typing import Any, cast | |||||
| from langchain.schema import Document | |||||
| from qdrant_client.http.models import Filter, FilterSelector, PointIdsList | |||||
| from qdrant_client.local.qdrant_local import QdrantLocal | |||||
| from core.vector_store.vector.qdrant import Qdrant | |||||
| class QdrantVectorStore(Qdrant): | |||||
| def del_texts(self, filter: Filter): | |||||
| if not filter: | |||||
| raise ValueError('filter must not be empty') | |||||
| self._reload_if_needed() | |||||
| self.client.delete( | |||||
| collection_name=self.collection_name, | |||||
| points_selector=FilterSelector( | |||||
| filter=filter | |||||
| ), | |||||
| ) | |||||
| def del_text(self, uuid: str) -> None: | |||||
| self._reload_if_needed() | |||||
| self.client.delete( | |||||
| collection_name=self.collection_name, | |||||
| points_selector=PointIdsList( | |||||
| points=[uuid], | |||||
| ), | |||||
| ) | |||||
| def text_exists(self, uuid: str) -> bool: | |||||
| self._reload_if_needed() | |||||
| response = self.client.retrieve( | |||||
| collection_name=self.collection_name, | |||||
| ids=[uuid] | |||||
| ) | |||||
| return len(response) > 0 | |||||
| def delete(self): | |||||
| self._reload_if_needed() | |||||
| self.client.delete_collection(collection_name=self.collection_name) | |||||
| def delete_group(self): | |||||
| self._reload_if_needed() | |||||
| self.client.delete_collection(collection_name=self.collection_name) | |||||
| @classmethod | |||||
| def _document_from_scored_point( | |||||
| cls, | |||||
| scored_point: Any, | |||||
| content_payload_key: str, | |||||
| metadata_payload_key: str, | |||||
| ) -> Document: | |||||
| if scored_point.payload.get('doc_id'): | |||||
| return Document( | |||||
| page_content=scored_point.payload.get(content_payload_key), | |||||
| metadata={'doc_id': scored_point.id} | |||||
| ) | |||||
| return Document( | |||||
| page_content=scored_point.payload.get(content_payload_key), | |||||
| metadata=scored_point.payload.get(metadata_payload_key) or {}, | |||||
| ) | |||||
| def _reload_if_needed(self): | |||||
| if isinstance(self.client, QdrantLocal): | |||||
| self.client = cast(QdrantLocal, self.client) | |||||
| self.client._load() | |||||
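QdrantVectorStore above deletes either whole filtered sets (FilterSelector) or single points (PointIdsList), reloading the client first whenever it is a local QdrantLocal instance. A hedged sketch of building such a filter; the payload key metadata.document_id is an assumption about how the surrounding code stores metadata, not something this hunk shows.

from qdrant_client.http import models

# Illustrative only: `store` is assumed to be a constructed QdrantVectorStore.
doc_filter = models.Filter(
    must=[
        models.FieldCondition(
            key="metadata.document_id",                      # assumed payload layout
            match=models.MatchValue(value="example-document-id"),
        )
    ]
)
store.del_texts(doc_filter)            # delete every point matching the filter
store.del_text("example-point-uuid")   # delete one point by id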
| """Wrapper around the Milvus vector database.""" | |||||
| from __future__ import annotations | |||||
| import logging | |||||
| from collections.abc import Iterable, Sequence | |||||
| from typing import Any, Optional, Union | |||||
| from uuid import uuid4 | |||||
| import numpy as np | |||||
| from langchain.docstore.document import Document | |||||
| from langchain.embeddings.base import Embeddings | |||||
| from langchain.vectorstores.base import VectorStore | |||||
| from langchain.vectorstores.utils import maximal_marginal_relevance | |||||
| logger = logging.getLogger(__name__) | |||||
| DEFAULT_MILVUS_CONNECTION = { | |||||
| "host": "localhost", | |||||
| "port": "19530", | |||||
| "user": "", | |||||
| "password": "", | |||||
| "secure": False, | |||||
| } | |||||
| class Milvus(VectorStore): | |||||
| """Initialize wrapper around the milvus vector database. | |||||
| In order to use this you need to have `pymilvus` installed and a | |||||
| running Milvus instance. | |||||
| See the following documentation for how to run a Milvus instance: | |||||
| https://milvus.io/docs/install_standalone-docker.md | |||||
| If looking for a hosted Milvus, take a look at this documentation: | |||||
| https://zilliz.com/cloud and make use of the Zilliz vectorstore found in | |||||
| this project. | |||||
| IF USING L2/IP metric IT IS HIGHLY SUGGESTED TO NORMALIZE YOUR DATA. | |||||
| Args: | |||||
| embedding_function (Embeddings): Function used to embed the text. | |||||
| collection_name (str): Which Milvus collection to use. Defaults to | |||||
| "LangChainCollection". | |||||
| connection_args (Optional[dict[str, any]]): The connection args used for | |||||
| this class come in the form of a dict. | |||||
| consistency_level (str): The consistency level to use for a collection. | |||||
| Defaults to "Session". | |||||
| index_params (Optional[dict]): Which index params to use. Defaults to | |||||
| HNSW/AUTOINDEX depending on service. | |||||
| search_params (Optional[dict]): Which search params to use. Defaults to | |||||
| default of index. | |||||
| drop_old (Optional[bool]): Whether to drop the current collection. Defaults | |||||
| to False. | |||||
| The connection args used for this class come in the form of a dict; | |||||
| here are a few of the options: | |||||
| address (str): The actual address of Milvus | |||||
| instance. Example address: "localhost:19530" | |||||
| uri (str): The uri of Milvus instance. Example uri: | |||||
| "http://randomwebsite:19530", | |||||
| "tcp:foobarsite:19530", | |||||
| "https://ok.s3.south.com:19530". | |||||
| host (str): The host of Milvus instance. Default at "localhost", | |||||
| PyMilvus will fill in the default host if only port is provided. | |||||
| port (str/int): The port of Milvus instance. Default at 19530, PyMilvus | |||||
| will fill in the default port if only host is provided. | |||||
| user (str): Which user to connect to the Milvus instance as. If user and | |||||
| password are provided, the related header is added to every RPC call. | |||||
| password (str): Required when user is provided. The password | |||||
| corresponding to the user. | |||||
| secure (bool): Default is false. If set to true, tls will be enabled. | |||||
| client_key_path (str): If using TLS two-way authentication, the | |||||
| path to the client.key file. | |||||
| client_pem_path (str): If using TLS two-way authentication, the | |||||
| path to the client.pem file. | |||||
| ca_pem_path (str): If using TLS two-way authentication, the path | |||||
| to the ca.pem file. | |||||
| server_pem_path (str): If using TLS one-way authentication, the | |||||
| path to the server.pem file. | |||||
| server_name (str): If using TLS, the common name to use. | |||||
| Example: | |||||
| .. code-block:: python | |||||
| from langchain import Milvus | |||||
| from langchain.embeddings import OpenAIEmbeddings | |||||
| embedding = OpenAIEmbeddings() | |||||
| # Connect to a milvus instance on localhost | |||||
| milvus_store = Milvus( | |||||
| embedding_function = embedding, | |||||
| collection_name = "LangChainCollection", | |||||
| drop_old = True, | |||||
| ) | |||||
| Raises: | |||||
| ValueError: If the pymilvus python package is not installed. | |||||
| """ | |||||
| def __init__( | |||||
| self, | |||||
| embedding_function: Embeddings, | |||||
| collection_name: str = "LangChainCollection", | |||||
| connection_args: Optional[dict[str, Any]] = None, | |||||
| consistency_level: str = "Session", | |||||
| index_params: Optional[dict] = None, | |||||
| search_params: Optional[dict] = None, | |||||
| drop_old: Optional[bool] = False, | |||||
| ): | |||||
| """Initialize the Milvus vector store.""" | |||||
| try: | |||||
| from pymilvus import Collection, utility | |||||
| except ImportError: | |||||
| raise ValueError( | |||||
| "Could not import pymilvus python package. " | |||||
| "Please install it with `pip install pymilvus`." | |||||
| ) | |||||
| # Default search params when one is not provided. | |||||
| self.default_search_params = { | |||||
| "IVF_FLAT": {"metric_type": "L2", "params": {"nprobe": 10}}, | |||||
| "IVF_SQ8": {"metric_type": "L2", "params": {"nprobe": 10}}, | |||||
| "IVF_PQ": {"metric_type": "L2", "params": {"nprobe": 10}}, | |||||
| "HNSW": {"metric_type": "L2", "params": {"ef": 10}}, | |||||
| "RHNSW_FLAT": {"metric_type": "L2", "params": {"ef": 10}}, | |||||
| "RHNSW_SQ": {"metric_type": "L2", "params": {"ef": 10}}, | |||||
| "RHNSW_PQ": {"metric_type": "L2", "params": {"ef": 10}}, | |||||
| "IVF_HNSW": {"metric_type": "L2", "params": {"nprobe": 10, "ef": 10}}, | |||||
| "ANNOY": {"metric_type": "L2", "params": {"search_k": 10}}, | |||||
| "AUTOINDEX": {"metric_type": "L2", "params": {}}, | |||||
| } | |||||
| self.embedding_func = embedding_function | |||||
| self.collection_name = collection_name | |||||
| self.index_params = index_params | |||||
| self.search_params = search_params | |||||
| self.consistency_level = consistency_level | |||||
| # In order for a collection to be compatible, pk needs to be auto-id and int | |||||
| self._primary_field = "id" | |||||
| # In order for compatibility, the text field will need to be called "text" | |||||
| self._text_field = "page_content" | |||||
| # In order for compatibility, the vector field needs to be called "vector" | |||||
| self._vector_field = "vectors" | |||||
| # In order for compatibility, the metadata field will need to be called "metadata" | |||||
| self._metadata_field = "metadata" | |||||
| self.fields: list[str] = [] | |||||
| # Create the connection to the server | |||||
| if connection_args is None: | |||||
| connection_args = DEFAULT_MILVUS_CONNECTION | |||||
| self.alias = self._create_connection_alias(connection_args) | |||||
| self.col: Optional[Collection] = None | |||||
| # Grab the existing collection if it exists | |||||
| if utility.has_collection(self.collection_name, using=self.alias): | |||||
| self.col = Collection( | |||||
| self.collection_name, | |||||
| using=self.alias, | |||||
| ) | |||||
| # If need to drop old, drop it | |||||
| if drop_old and isinstance(self.col, Collection): | |||||
| self.col.drop() | |||||
| self.col = None | |||||
| # Initialize the vector store | |||||
| self._init() | |||||
| @property | |||||
| def embeddings(self) -> Embeddings: | |||||
| return self.embedding_func | |||||
| def _create_connection_alias(self, connection_args: dict) -> str: | |||||
| """Create the connection to the Milvus server.""" | |||||
| from pymilvus import MilvusException, connections | |||||
| # Grab the connection arguments that are used for checking existing connection | |||||
| host: str = connection_args.get("host", None) | |||||
| port: Union[str, int] = connection_args.get("port", None) | |||||
| address: str = connection_args.get("address", None) | |||||
| uri: str = connection_args.get("uri", None) | |||||
| user = connection_args.get("user", None) | |||||
| # Order of use is host/port, uri, address | |||||
| if host is not None and port is not None: | |||||
| given_address = str(host) + ":" + str(port) | |||||
| elif uri is not None: | |||||
| given_address = uri.split("https://")[1] | |||||
| elif address is not None: | |||||
| given_address = address | |||||
| else: | |||||
| given_address = None | |||||
| logger.debug("Missing standard address type for reuse attempt") | |||||
| # User defaults to empty string when getting connection info | |||||
| if user is not None: | |||||
| tmp_user = user | |||||
| else: | |||||
| tmp_user = "" | |||||
| # If a valid address was given, then check if a connection exists | |||||
| if given_address is not None: | |||||
| for con in connections.list_connections(): | |||||
| addr = connections.get_connection_addr(con[0]) | |||||
| if ( | |||||
| con[1] | |||||
| and ("address" in addr) | |||||
| and (addr["address"] == given_address) | |||||
| and ("user" in addr) | |||||
| and (addr["user"] == tmp_user) | |||||
| ): | |||||
| logger.debug("Using previous connection: %s", con[0]) | |||||
| return con[0] | |||||
| # Generate a new connection if one doesn't exist | |||||
| alias = uuid4().hex | |||||
| try: | |||||
| connections.connect(alias=alias, **connection_args) | |||||
| logger.debug("Created new connection using: %s", alias) | |||||
| return alias | |||||
| except MilvusException as e: | |||||
| logger.error("Failed to create new connection using: %s", alias) | |||||
| raise e | |||||
| def _init( | |||||
| self, embeddings: Optional[list] = None, metadatas: Optional[list[dict]] = None | |||||
| ) -> None: | |||||
| if embeddings is not None: | |||||
| self._create_collection(embeddings, metadatas) | |||||
| self._extract_fields() | |||||
| self._create_index() | |||||
| self._create_search_params() | |||||
| self._load() | |||||
| def _create_collection( | |||||
| self, embeddings: list, metadatas: Optional[list[dict]] = None | |||||
| ) -> None: | |||||
| from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, MilvusException | |||||
| from pymilvus.orm.types import infer_dtype_bydata | |||||
| # Determine embedding dim | |||||
| dim = len(embeddings[0]) | |||||
| fields = [] | |||||
| # Determine metadata schema | |||||
| # if metadatas: | |||||
| # # Create FieldSchema for each entry in metadata. | |||||
| # for key, value in metadatas[0].items(): | |||||
| # # Infer the corresponding datatype of the metadata | |||||
| # dtype = infer_dtype_bydata(value) | |||||
| # # Datatype isn't compatible | |||||
| # if dtype == DataType.UNKNOWN or dtype == DataType.NONE: | |||||
| # logger.error( | |||||
| # "Failure to create collection, unrecognized dtype for key: %s", | |||||
| # key, | |||||
| # ) | |||||
| # raise ValueError(f"Unrecognized datatype for {key}.") | |||||
| # # Datatype is a string/varchar equivalent | |||||
| # elif dtype == DataType.VARCHAR: | |||||
| # fields.append(FieldSchema(key, DataType.VARCHAR, max_length=65_535)) | |||||
| # else: | |||||
| # fields.append(FieldSchema(key, dtype)) | |||||
| if metadatas: | |||||
| fields.append(FieldSchema(self._metadata_field, DataType.JSON, max_length=65_535)) | |||||
| # Create the text field | |||||
| fields.append( | |||||
| FieldSchema(self._text_field, DataType.VARCHAR, max_length=65_535) | |||||
| ) | |||||
| # Create the primary key field | |||||
| fields.append( | |||||
| FieldSchema( | |||||
| self._primary_field, DataType.INT64, is_primary=True, auto_id=True | |||||
| ) | |||||
| ) | |||||
| # Create the vector field, supports binary or float vectors | |||||
| fields.append( | |||||
| FieldSchema(self._vector_field, infer_dtype_bydata(embeddings[0]), dim=dim) | |||||
| ) | |||||
| # Create the schema for the collection | |||||
| schema = CollectionSchema(fields) | |||||
| # Create the collection | |||||
| try: | |||||
| self.col = Collection( | |||||
| name=self.collection_name, | |||||
| schema=schema, | |||||
| consistency_level=self.consistency_level, | |||||
| using=self.alias, | |||||
| ) | |||||
| except MilvusException as e: | |||||
| logger.error( | |||||
| "Failed to create collection: %s error: %s", self.collection_name, e | |||||
| ) | |||||
| raise e | |||||
| def _extract_fields(self) -> None: | |||||
| """Grab the existing fields from the Collection""" | |||||
| from pymilvus import Collection | |||||
| if isinstance(self.col, Collection): | |||||
| schema = self.col.schema | |||||
| for x in schema.fields: | |||||
| self.fields.append(x.name) | |||||
| # Since primary field is auto-id, no need to track it | |||||
| self.fields.remove(self._primary_field) | |||||
| def _get_index(self) -> Optional[dict[str, Any]]: | |||||
| """Return the vector index information if it exists""" | |||||
| from pymilvus import Collection | |||||
| if isinstance(self.col, Collection): | |||||
| for x in self.col.indexes: | |||||
| if x.field_name == self._vector_field: | |||||
| return x.to_dict() | |||||
| return None | |||||
| def _create_index(self) -> None: | |||||
| """Create an index on the collection""" | |||||
| from pymilvus import Collection, MilvusException | |||||
| if isinstance(self.col, Collection) and self._get_index() is None: | |||||
| try: | |||||
| # If no index params, use a default HNSW based one | |||||
| if self.index_params is None: | |||||
| self.index_params = { | |||||
| "metric_type": "IP", | |||||
| "index_type": "HNSW", | |||||
| "params": {"M": 8, "efConstruction": 64}, | |||||
| } | |||||
| try: | |||||
| self.col.create_index( | |||||
| self._vector_field, | |||||
| index_params=self.index_params, | |||||
| using=self.alias, | |||||
| ) | |||||
| # If default did not work, most likely on Zilliz Cloud | |||||
| except MilvusException: | |||||
| # Use AUTOINDEX based index | |||||
| self.index_params = { | |||||
| "metric_type": "L2", | |||||
| "index_type": "AUTOINDEX", | |||||
| "params": {}, | |||||
| } | |||||
| self.col.create_index( | |||||
| self._vector_field, | |||||
| index_params=self.index_params, | |||||
| using=self.alias, | |||||
| ) | |||||
| logger.debug( | |||||
| "Successfully created an index on collection: %s", | |||||
| self.collection_name, | |||||
| ) | |||||
| except MilvusException as e: | |||||
| logger.error( | |||||
| "Failed to create an index on collection: %s", self.collection_name | |||||
| ) | |||||
| raise e | |||||
| def _create_search_params(self) -> None: | |||||
| """Generate search params based on the current index type""" | |||||
| from pymilvus import Collection | |||||
| if isinstance(self.col, Collection) and self.search_params is None: | |||||
| index = self._get_index() | |||||
| if index is not None: | |||||
| index_type: str = index["index_param"]["index_type"] | |||||
| metric_type: str = index["index_param"]["metric_type"] | |||||
| self.search_params = self.default_search_params[index_type] | |||||
| self.search_params["metric_type"] = metric_type | |||||
| def _load(self) -> None: | |||||
| """Load the collection if available.""" | |||||
| from pymilvus import Collection | |||||
| if isinstance(self.col, Collection) and self._get_index() is not None: | |||||
| self.col.load() | |||||
| def add_texts( | |||||
| self, | |||||
| texts: Iterable[str], | |||||
| metadatas: Optional[list[dict]] = None, | |||||
| timeout: Optional[int] = None, | |||||
| batch_size: int = 1000, | |||||
| **kwargs: Any, | |||||
| ) -> list[str]: | |||||
| """Insert text data into Milvus. | |||||
| Inserting data when the collection has not been made yet will result | |||||
| in creating a new Collection. The data of the first entity decides | |||||
| the schema of the new collection, the dim is extracted from the first | |||||
| embedding and the columns are decided by the first metadata dict. | |||||
| Metadata keys will need to be present for all inserted values. At | |||||
| the moment there is no None equivalent in Milvus. | |||||
| Args: | |||||
| texts (Iterable[str]): The texts to embed, it is assumed | |||||
| that they all fit in memory. | |||||
| metadatas (Optional[List[dict]]): Metadata dicts attached to each of | |||||
| the texts. Defaults to None. | |||||
| timeout (Optional[int]): Timeout for each batch insert. Defaults | |||||
| to None. | |||||
| batch_size (int, optional): Batch size to use for insertion. | |||||
| Defaults to 1000. | |||||
| Raises: | |||||
| MilvusException: Failure to add texts | |||||
| Returns: | |||||
| List[str]: The resulting keys for each inserted element. | |||||
| """ | |||||
| from pymilvus import Collection, MilvusException | |||||
| texts = list(texts) | |||||
| try: | |||||
| embeddings = self.embedding_func.embed_documents(texts) | |||||
| except NotImplementedError: | |||||
| embeddings = [self.embedding_func.embed_query(x) for x in texts] | |||||
| if len(embeddings) == 0: | |||||
| logger.debug("Nothing to insert, skipping.") | |||||
| return [] | |||||
| # If the collection hasn't been initialized yet, perform all steps to do so | |||||
| if not isinstance(self.col, Collection): | |||||
| self._init(embeddings, metadatas) | |||||
| # Dict to hold all insert columns | |||||
| insert_dict: dict[str, list] = { | |||||
| self._text_field: texts, | |||||
| self._vector_field: embeddings, | |||||
| } | |||||
| # Collect the metadata into the insert dict. | |||||
| # if metadatas is not None: | |||||
| # for d in metadatas: | |||||
| # for key, value in d.items(): | |||||
| # if key in self.fields: | |||||
| # insert_dict.setdefault(key, []).append(value) | |||||
| if metadatas is not None: | |||||
| for d in metadatas: | |||||
| insert_dict.setdefault(self._metadata_field, []).append(d) | |||||
| # Total insert count | |||||
| vectors: list = insert_dict[self._vector_field] | |||||
| total_count = len(vectors) | |||||
| pks: list[str] = [] | |||||
| assert isinstance(self.col, Collection) | |||||
| for i in range(0, total_count, batch_size): | |||||
| # Grab end index | |||||
| end = min(i + batch_size, total_count) | |||||
| # Convert dict to list of lists batch for insertion | |||||
| insert_list = [insert_dict[x][i:end] for x in self.fields] | |||||
| # Insert into the collection. | |||||
| try: | |||||
| res: Collection | |||||
| res = self.col.insert(insert_list, timeout=timeout, **kwargs) | |||||
| pks.extend(res.primary_keys) | |||||
| except MilvusException as e: | |||||
| logger.error( | |||||
| "Failed to insert batch starting at entity: %s/%s", i, total_count | |||||
| ) | |||||
| raise e | |||||
| return pks | |||||
| def similarity_search( | |||||
| self, | |||||
| query: str, | |||||
| k: int = 4, | |||||
| param: Optional[dict] = None, | |||||
| expr: Optional[str] = None, | |||||
| timeout: Optional[int] = None, | |||||
| **kwargs: Any, | |||||
| ) -> list[Document]: | |||||
| """Perform a similarity search against the query string. | |||||
| Args: | |||||
| query (str): The text to search. | |||||
| k (int, optional): How many results to return. Defaults to 4. | |||||
| param (dict, optional): The search params for the index type. | |||||
| Defaults to None. | |||||
| expr (str, optional): Filtering expression. Defaults to None. | |||||
| timeout (int, optional): How long to wait before timeout error. | |||||
| Defaults to None. | |||||
| kwargs: Collection.search() keyword arguments. | |||||
| Returns: | |||||
| List[Document]: Document results for search. | |||||
| """ | |||||
| if self.col is None: | |||||
| logger.debug("No existing collection to search.") | |||||
| return [] | |||||
| res = self.similarity_search_with_score( | |||||
| query=query, k=k, param=param, expr=expr, timeout=timeout, **kwargs | |||||
| ) | |||||
| return [doc for doc, _ in res] | |||||
| def similarity_search_by_vector( | |||||
| self, | |||||
| embedding: list[float], | |||||
| k: int = 4, | |||||
| param: Optional[dict] = None, | |||||
| expr: Optional[str] = None, | |||||
| timeout: Optional[int] = None, | |||||
| **kwargs: Any, | |||||
| ) -> list[Document]: | |||||
| """Perform a similarity search against the query string. | |||||
| Args: | |||||
| embedding (List[float]): The embedding vector to search. | |||||
| k (int, optional): How many results to return. Defaults to 4. | |||||
| param (dict, optional): The search params for the index type. | |||||
| Defaults to None. | |||||
| expr (str, optional): Filtering expression. Defaults to None. | |||||
| timeout (int, optional): How long to wait before timeout error. | |||||
| Defaults to None. | |||||
| kwargs: Collection.search() keyword arguments. | |||||
| Returns: | |||||
| List[Document]: Document results for search. | |||||
| """ | |||||
| if self.col is None: | |||||
| logger.debug("No existing collection to search.") | |||||
| return [] | |||||
| res = self.similarity_search_with_score_by_vector( | |||||
| embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs | |||||
| ) | |||||
| return [doc for doc, _ in res] | |||||
| def similarity_search_with_score( | |||||
| self, | |||||
| query: str, | |||||
| k: int = 4, | |||||
| param: Optional[dict] = None, | |||||
| expr: Optional[str] = None, | |||||
| timeout: Optional[int] = None, | |||||
| **kwargs: Any, | |||||
| ) -> list[tuple[Document, float]]: | |||||
| """Perform a search on a query string and return results with score. | |||||
| For more information about the search parameters, take a look at the pymilvus | |||||
| documentation found here: | |||||
| https://milvus.io/api-reference/pymilvus/v2.2.6/Collection/search().md | |||||
| Args: | |||||
| query (str): The text being searched. | |||||
| k (int, optional): The amount of results to return. Defaults to 4. | |||||
| param (dict): The search params for the specified index. | |||||
| Defaults to None. | |||||
| expr (str, optional): Filtering expression. Defaults to None. | |||||
| timeout (int, optional): How long to wait before timeout error. | |||||
| Defaults to None. | |||||
| kwargs: Collection.search() keyword arguments. | |||||
| Returns: | |||||
| List[Tuple[Document, float]]: Result doc and score. | |||||
| """ | |||||
| if self.col is None: | |||||
| logger.debug("No existing collection to search.") | |||||
| return [] | |||||
| # Embed the query text. | |||||
| embedding = self.embedding_func.embed_query(query) | |||||
| res = self.similarity_search_with_score_by_vector( | |||||
| embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs | |||||
| ) | |||||
| return res | |||||
| def _similarity_search_with_relevance_scores( | |||||
| self, | |||||
| query: str, | |||||
| k: int = 4, | |||||
| **kwargs: Any, | |||||
| ) -> list[tuple[Document, float]]: | |||||
| """Return docs and relevance scores in the range [0, 1]. | |||||
| 0 is dissimilar, 1 is most similar. | |||||
| Args: | |||||
| query: input text | |||||
| k: Number of Documents to return. Defaults to 4. | |||||
| **kwargs: kwargs to be passed to similarity search. Should include: | |||||
| score_threshold: Optional, a floating point value between 0 and 1 to | |||||
| filter the resulting set of retrieved docs | |||||
| Returns: | |||||
| List of Tuples of (doc, similarity_score) | |||||
| """ | |||||
| return self.similarity_search_with_score(query, k, **kwargs) | |||||
| def similarity_search_with_score_by_vector( | |||||
| self, | |||||
| embedding: list[float], | |||||
| k: int = 4, | |||||
| param: Optional[dict] = None, | |||||
| expr: Optional[str] = None, | |||||
| timeout: Optional[int] = None, | |||||
| **kwargs: Any, | |||||
| ) -> list[tuple[Document, float]]: | |||||
| """Perform a search on a query string and return results with score. | |||||
| For more information about the search parameters, take a look at the pymilvus | |||||
| documentation found here: | |||||
| https://milvus.io/api-reference/pymilvus/v2.2.6/Collection/search().md | |||||
| Args: | |||||
| embedding (List[float]): The embedding vector being searched. | |||||
| k (int, optional): The amount of results to return. Defaults to 4. | |||||
| param (dict): The search params for the specified index. | |||||
| Defaults to None. | |||||
| expr (str, optional): Filtering expression. Defaults to None. | |||||
| timeout (int, optional): How long to wait before timeout error. | |||||
| Defaults to None. | |||||
| kwargs: Collection.search() keyword arguments. | |||||
| Returns: | |||||
| List[Tuple[Document, float]]: Result doc and score. | |||||
| """ | |||||
| if self.col is None: | |||||
| logger.debug("No existing collection to search.") | |||||
| return [] | |||||
| if param is None: | |||||
| param = self.search_params | |||||
| # Determine result metadata fields. | |||||
| output_fields = self.fields[:] | |||||
| output_fields.remove(self._vector_field) | |||||
| # Perform the search. | |||||
| res = self.col.search( | |||||
| data=[embedding], | |||||
| anns_field=self._vector_field, | |||||
| param=param, | |||||
| limit=k, | |||||
| expr=expr, | |||||
| output_fields=output_fields, | |||||
| timeout=timeout, | |||||
| **kwargs, | |||||
| ) | |||||
| # Organize results. | |||||
| ret = [] | |||||
| for result in res[0]: | |||||
| meta = {x: result.entity.get(x) for x in output_fields} | |||||
| doc = Document(page_content=meta.pop(self._text_field), metadata=meta.get('metadata')) | |||||
| pair = (doc, result.score) | |||||
| ret.append(pair) | |||||
| return ret | |||||
| def max_marginal_relevance_search( | |||||
| self, | |||||
| query: str, | |||||
| k: int = 4, | |||||
| fetch_k: int = 20, | |||||
| lambda_mult: float = 0.5, | |||||
| param: Optional[dict] = None, | |||||
| expr: Optional[str] = None, | |||||
| timeout: Optional[int] = None, | |||||
| **kwargs: Any, | |||||
| ) -> list[Document]: | |||||
| """Perform a search and return results that are reordered by MMR. | |||||
| Args: | |||||
| query (str): The text being searched. | |||||
| k (int, optional): How many results to give. Defaults to 4. | |||||
| fetch_k (int, optional): Total results to select k from. | |||||
| Defaults to 20. | |||||
| lambda_mult: Number between 0 and 1 that determines the degree | |||||
| of diversity among the results with 0 corresponding | |||||
| to maximum diversity and 1 to minimum diversity. | |||||
| Defaults to 0.5 | |||||
| param (dict, optional): The search params for the specified index. | |||||
| Defaults to None. | |||||
| expr (str, optional): Filtering expression. Defaults to None. | |||||
| timeout (int, optional): How long to wait before timeout error. | |||||
| Defaults to None. | |||||
| kwargs: Collection.search() keyword arguments. | |||||
| Returns: | |||||
| List[Document]: Document results for search. | |||||
| """ | |||||
| if self.col is None: | |||||
| logger.debug("No existing collection to search.") | |||||
| return [] | |||||
| embedding = self.embedding_func.embed_query(query) | |||||
| return self.max_marginal_relevance_search_by_vector( | |||||
| embedding=embedding, | |||||
| k=k, | |||||
| fetch_k=fetch_k, | |||||
| lambda_mult=lambda_mult, | |||||
| param=param, | |||||
| expr=expr, | |||||
| timeout=timeout, | |||||
| **kwargs, | |||||
| ) | |||||
| def max_marginal_relevance_search_by_vector( | |||||
| self, | |||||
| embedding: list[float], | |||||
| k: int = 4, | |||||
| fetch_k: int = 20, | |||||
| lambda_mult: float = 0.5, | |||||
| param: Optional[dict] = None, | |||||
| expr: Optional[str] = None, | |||||
| timeout: Optional[int] = None, | |||||
| **kwargs: Any, | |||||
| ) -> list[Document]: | |||||
| """Perform a search and return results that are reordered by MMR. | |||||
| Args: | |||||
| embedding (List[float]): The embedding vector being searched. | |||||
| k (int, optional): How many results to give. Defaults to 4. | |||||
| fetch_k (int, optional): Total results to select k from. | |||||
| Defaults to 20. | |||||
| lambda_mult: Number between 0 and 1 that determines the degree | |||||
| of diversity among the results with 0 corresponding | |||||
| to maximum diversity and 1 to minimum diversity. | |||||
| Defaults to 0.5 | |||||
| param (dict, optional): The search params for the specified index. | |||||
| Defaults to None. | |||||
| expr (str, optional): Filtering expression. Defaults to None. | |||||
| timeout (int, optional): How long to wait before timeout error. | |||||
| Defaults to None. | |||||
| kwargs: Collection.search() keyword arguments. | |||||
| Returns: | |||||
| List[Document]: Document results for search. | |||||
| """ | |||||
| if self.col is None: | |||||
| logger.debug("No existing collection to search.") | |||||
| return [] | |||||
| if param is None: | |||||
| param = self.search_params | |||||
| # Determine result metadata fields. | |||||
| output_fields = self.fields[:] | |||||
| output_fields.remove(self._vector_field) | |||||
| # Perform the search. | |||||
| res = self.col.search( | |||||
| data=[embedding], | |||||
| anns_field=self._vector_field, | |||||
| param=param, | |||||
| limit=fetch_k, | |||||
| expr=expr, | |||||
| output_fields=output_fields, | |||||
| timeout=timeout, | |||||
| **kwargs, | |||||
| ) | |||||
| # Organize results. | |||||
| ids = [] | |||||
| documents = [] | |||||
| scores = [] | |||||
| for result in res[0]: | |||||
| meta = {x: result.entity.get(x) for x in output_fields} | |||||
| doc = Document(page_content=meta.pop(self._text_field), metadata=meta) | |||||
| documents.append(doc) | |||||
| scores.append(result.score) | |||||
| ids.append(result.id) | |||||
| vectors = self.col.query( | |||||
| expr=f"{self._primary_field} in {ids}", | |||||
| output_fields=[self._primary_field, self._vector_field], | |||||
| timeout=timeout, | |||||
| ) | |||||
| # Reorganize the results from query to match search order. | |||||
| vectors = {x[self._primary_field]: x[self._vector_field] for x in vectors} | |||||
| ordered_result_embeddings = [vectors[x] for x in ids] | |||||
| # Get the new order of results. | |||||
| new_ordering = maximal_marginal_relevance( | |||||
| np.array(embedding), ordered_result_embeddings, k=k, lambda_mult=lambda_mult | |||||
| ) | |||||
| # Reorder the values and return. | |||||
| ret = [] | |||||
| for x in new_ordering: | |||||
| # Function can return -1 index | |||||
| if x == -1: | |||||
| break | |||||
| else: | |||||
| ret.append(documents[x]) | |||||
| return ret | |||||
| @classmethod | |||||
| def from_texts( | |||||
| cls, | |||||
| texts: list[str], | |||||
| embedding: Embeddings, | |||||
| metadatas: Optional[list[dict]] = None, | |||||
| collection_name: str = "LangChainCollection", | |||||
| connection_args: dict[str, Any] = DEFAULT_MILVUS_CONNECTION, | |||||
| consistency_level: str = "Session", | |||||
| index_params: Optional[dict] = None, | |||||
| search_params: Optional[dict] = None, | |||||
| drop_old: bool = False, | |||||
| batch_size: int = 100, | |||||
| ids: Optional[Sequence[str]] = None, | |||||
| **kwargs: Any, | |||||
| ) -> Milvus: | |||||
| """Create a Milvus collection, index it with HNSW, and insert data. | |||||
| Args: | |||||
| texts (List[str]): Text data. | |||||
| embedding (Embeddings): Embedding function. | |||||
| metadatas (Optional[List[dict]]): Metadata for each text if it exists. | |||||
| Defaults to None. | |||||
| collection_name (str, optional): Collection name to use. Defaults to | |||||
| "LangChainCollection". | |||||
| connection_args (dict[str, Any], optional): Connection args to use. Defaults | |||||
| to DEFAULT_MILVUS_CONNECTION. | |||||
| consistency_level (str, optional): Which consistency level to use. Defaults | |||||
| to "Session". | |||||
| index_params (Optional[dict], optional): Which index_params to use. Defaults | |||||
| to None. | |||||
| search_params (Optional[dict], optional): Which search params to use. | |||||
| Defaults to None. | |||||
| drop_old (Optional[bool], optional): Whether to drop the collection with | |||||
| that name if it exists. Defaults to False. | |||||
| batch_size: | |||||
| How many vectors upload per-request. | |||||
| Default: 100 | |||||
| ids (Optional[Sequence[str]]): Optional ids for the inserted texts. Defaults to None. | |||||
| Returns: | |||||
| Milvus: Milvus Vector Store | |||||
| """ | |||||
| vector_db = cls( | |||||
| embedding_function=embedding, | |||||
| collection_name=collection_name, | |||||
| connection_args=connection_args, | |||||
| consistency_level=consistency_level, | |||||
| index_params=index_params, | |||||
| search_params=search_params, | |||||
| drop_old=drop_old, | |||||
| **kwargs, | |||||
| ) | |||||
| vector_db.add_texts(texts=texts, metadatas=metadatas, batch_size=batch_size) | |||||
| return vector_db |
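The vendored Milvus wrapper above stores chunk text in page_content, vectors in vectors, and all metadata in a single JSON field named metadata. A hedged end-to-end sketch of using it; the embedding class, collection name, and connection settings are placeholders for whatever the caller supplies.

from langchain.embeddings import OpenAIEmbeddings  # any Embeddings implementation works here

embedding = OpenAIEmbeddings()

milvus_store = Milvus.from_texts(
    texts=[
        "Chunk text lands in the page_content field.",
        "All metadata is kept in one JSON field called metadata.",
    ],
    embedding=embedding,
    metadatas=[{"doc_id": "a"}, {"doc_id": "b"}],
    collection_name="ExampleCollection",                     # placeholder name
    connection_args={"host": "localhost", "port": "19530"},  # matches DEFAULT_MILVUS_CONNECTION
)

# embeds the query and runs an ANN search over the `vectors` field
for doc in milvus_store.similarity_search("where does the text go?", k=2):
    print(doc.page_content, doc.metadata)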
| """Wrapper around weaviate vector database.""" | |||||
| from __future__ import annotations | |||||
| import datetime | |||||
| from collections.abc import Callable, Iterable | |||||
| from typing import Any, Optional | |||||
| from uuid import uuid4 | |||||
| import numpy as np | |||||
| from langchain.docstore.document import Document | |||||
| from langchain.embeddings.base import Embeddings | |||||
| from langchain.utils import get_from_dict_or_env | |||||
| from langchain.vectorstores.base import VectorStore | |||||
| from langchain.vectorstores.utils import maximal_marginal_relevance | |||||
| def _default_schema(index_name: str) -> dict: | |||||
| return { | |||||
| "class": index_name, | |||||
| "properties": [ | |||||
| { | |||||
| "name": "text", | |||||
| "dataType": ["text"], | |||||
| } | |||||
| ], | |||||
| } | |||||
| def _create_weaviate_client(**kwargs: Any) -> Any: | |||||
| client = kwargs.get("client") | |||||
| if client is not None: | |||||
| return client | |||||
| weaviate_url = get_from_dict_or_env(kwargs, "weaviate_url", "WEAVIATE_URL") | |||||
| try: | |||||
| # the weaviate api key param should not be mandatory | |||||
| weaviate_api_key = get_from_dict_or_env( | |||||
| kwargs, "weaviate_api_key", "WEAVIATE_API_KEY", None | |||||
| ) | |||||
| except ValueError: | |||||
| weaviate_api_key = None | |||||
| try: | |||||
| import weaviate | |||||
| except ImportError: | |||||
| raise ValueError( | |||||
| "Could not import weaviate python package. " | |||||
| "Please install it with `pip install weaviate-client`" | |||||
| ) | |||||
| auth = ( | |||||
| weaviate.auth.AuthApiKey(api_key=weaviate_api_key) | |||||
| if weaviate_api_key is not None | |||||
| else None | |||||
| ) | |||||
| client = weaviate.Client(weaviate_url, auth_client_secret=auth) | |||||
| return client | |||||
| def _default_score_normalizer(val: float) -> float: | |||||
| return 1 - val | |||||
| def _json_serializable(value: Any) -> Any: | |||||
| if isinstance(value, datetime.datetime): | |||||
| return value.isoformat() | |||||
| return value | |||||
| class Weaviate(VectorStore): | |||||
| """Wrapper around Weaviate vector database. | |||||
| To use, you should have the ``weaviate-client`` python package installed. | |||||
| Example: | |||||
| .. code-block:: python | |||||
| import weaviate | |||||
| from langchain.vectorstores import Weaviate | |||||
| client = weaviate.Client(url=os.environ["WEAVIATE_URL"], ...) | |||||
| weaviate = Weaviate(client, index_name, text_key) | |||||
| """ | |||||
| def __init__( | |||||
| self, | |||||
| client: Any, | |||||
| index_name: str, | |||||
| text_key: str, | |||||
| embedding: Optional[Embeddings] = None, | |||||
| attributes: Optional[list[str]] = None, | |||||
| relevance_score_fn: Optional[ | |||||
| Callable[[float], float] | |||||
| ] = _default_score_normalizer, | |||||
| by_text: bool = True, | |||||
| ): | |||||
| """Initialize with Weaviate client.""" | |||||
| try: | |||||
| import weaviate | |||||
| except ImportError: | |||||
| raise ValueError( | |||||
| "Could not import weaviate python package. " | |||||
| "Please install it with `pip install weaviate-client`." | |||||
| ) | |||||
| if not isinstance(client, weaviate.Client): | |||||
| raise ValueError( | |||||
| f"client should be an instance of weaviate.Client, got {type(client)}" | |||||
| ) | |||||
| self._client = client | |||||
| self._index_name = index_name | |||||
| self._embedding = embedding | |||||
| self._text_key = text_key | |||||
| self._query_attrs = [self._text_key] | |||||
| self.relevance_score_fn = relevance_score_fn | |||||
| self._by_text = by_text | |||||
| if attributes is not None: | |||||
| self._query_attrs.extend(attributes) | |||||
| @property | |||||
| def embeddings(self) -> Optional[Embeddings]: | |||||
| return self._embedding | |||||
| def _select_relevance_score_fn(self) -> Callable[[float], float]: | |||||
| return ( | |||||
| self.relevance_score_fn | |||||
| if self.relevance_score_fn | |||||
| else _default_score_normalizer | |||||
| ) | |||||
| def add_texts( | |||||
| self, | |||||
| texts: Iterable[str], | |||||
| metadatas: Optional[list[dict]] = None, | |||||
| **kwargs: Any, | |||||
| ) -> list[str]: | |||||
| """Upload texts with metadata (properties) to Weaviate.""" | |||||
| from weaviate.util import get_valid_uuid | |||||
| ids = [] | |||||
| embeddings: Optional[list[list[float]]] = None | |||||
| if self._embedding: | |||||
| if not isinstance(texts, list): | |||||
| texts = list(texts) | |||||
| embeddings = self._embedding.embed_documents(texts) | |||||
| with self._client.batch as batch: | |||||
| for i, text in enumerate(texts): | |||||
| data_properties = {self._text_key: text} | |||||
| if metadatas is not None: | |||||
| for key, val in metadatas[i].items(): | |||||
| data_properties[key] = _json_serializable(val) | |||||
| # Allow for ids (consistent w/ other methods) | |||||
| # # Or uuids (backwards compatible w/ existing arg) | |||||
| # If the UUID of one of the objects already exists | |||||
| # then the existing object will be replaced by the new object. | |||||
| _id = get_valid_uuid(uuid4()) | |||||
| if "uuids" in kwargs: | |||||
| _id = kwargs["uuids"][i] | |||||
| elif "ids" in kwargs: | |||||
| _id = kwargs["ids"][i] | |||||
| batch.add_data_object( | |||||
| data_object=data_properties, | |||||
| class_name=self._index_name, | |||||
| uuid=_id, | |||||
| vector=embeddings[i] if embeddings else None, | |||||
| ) | |||||
| ids.append(_id) | |||||
| return ids | |||||
| def similarity_search( | |||||
| self, query: str, k: int = 4, **kwargs: Any | |||||
| ) -> list[Document]: | |||||
| """Return docs most similar to query. | |||||
| Args: | |||||
| query: Text to look up documents similar to. | |||||
| k: Number of Documents to return. Defaults to 4. | |||||
| Returns: | |||||
| List of Documents most similar to the query. | |||||
| """ | |||||
| if self._by_text: | |||||
| return self.similarity_search_by_text(query, k, **kwargs) | |||||
| else: | |||||
| if self._embedding is None: | |||||
| raise ValueError( | |||||
| "_embedding cannot be None for similarity_search when " | |||||
| "_by_text=False" | |||||
| ) | |||||
| embedding = self._embedding.embed_query(query) | |||||
| return self.similarity_search_by_vector(embedding, k, **kwargs) | |||||
| def similarity_search_by_text( | |||||
| self, query: str, k: int = 4, **kwargs: Any | |||||
| ) -> list[Document]: | |||||
| """Return docs most similar to query. | |||||
| Args: | |||||
| query: Text to look up documents similar to. | |||||
| k: Number of Documents to return. Defaults to 4. | |||||
| Returns: | |||||
| List of Documents most similar to the query. | |||||
| """ | |||||
| content: dict[str, Any] = {"concepts": [query]} | |||||
| if kwargs.get("search_distance"): | |||||
| content["certainty"] = kwargs.get("search_distance") | |||||
| query_obj = self._client.query.get(self._index_name, self._query_attrs) | |||||
| if kwargs.get("where_filter"): | |||||
| query_obj = query_obj.with_where(kwargs.get("where_filter")) | |||||
| if kwargs.get("additional"): | |||||
| query_obj = query_obj.with_additional(kwargs.get("additional")) | |||||
| result = query_obj.with_near_text(content).with_limit(k).do() | |||||
| if "errors" in result: | |||||
| raise ValueError(f"Error during query: {result['errors']}") | |||||
| docs = [] | |||||
| for res in result["data"]["Get"][self._index_name]: | |||||
| text = res.pop(self._text_key) | |||||
| docs.append(Document(page_content=text, metadata=res)) | |||||
| return docs | |||||
| def similarity_search_by_bm25( | |||||
| self, query: str, k: int = 4, **kwargs: Any | |||||
| ) -> list[Document]: | |||||
| """Return docs using BM25F. | |||||
| Args: | |||||
| query: Text to look up documents similar to. | |||||
| k: Number of Documents to return. Defaults to 4. | |||||
| Returns: | |||||
| List of Documents most similar to the query. | |||||
| """ | |||||
| content: dict[str, Any] = {"concepts": [query]} | |||||
| if kwargs.get("search_distance"): | |||||
| content["certainty"] = kwargs.get("search_distance") | |||||
| query_obj = self._client.query.get(self._index_name, self._query_attrs) | |||||
| if kwargs.get("where_filter"): | |||||
| query_obj = query_obj.with_where(kwargs.get("where_filter")) | |||||
| if kwargs.get("additional"): | |||||
| query_obj = query_obj.with_additional(kwargs.get("additional")) | |||||
| properties = ['text'] | |||||
| result = query_obj.with_bm25(query=query, properties=properties).with_limit(k).do() | |||||
| if "errors" in result: | |||||
| raise ValueError(f"Error during query: {result['errors']}") | |||||
| docs = [] | |||||
| for res in result["data"]["Get"][self._index_name]: | |||||
| text = res.pop(self._text_key) | |||||
| docs.append(Document(page_content=text, metadata=res)) | |||||
| return docs | |||||
| def similarity_search_by_vector( | |||||
| self, embedding: list[float], k: int = 4, **kwargs: Any | |||||
| ) -> list[Document]: | |||||
| """Look up similar documents by embedding vector in Weaviate.""" | |||||
| vector = {"vector": embedding} | |||||
| query_obj = self._client.query.get(self._index_name, self._query_attrs) | |||||
| if kwargs.get("where_filter"): | |||||
| query_obj = query_obj.with_where(kwargs.get("where_filter")) | |||||
| if kwargs.get("additional"): | |||||
| query_obj = query_obj.with_additional(kwargs.get("additional")) | |||||
| result = query_obj.with_near_vector(vector).with_limit(k).do() | |||||
| if "errors" in result: | |||||
| raise ValueError(f"Error during query: {result['errors']}") | |||||
| docs = [] | |||||
| for res in result["data"]["Get"][self._index_name]: | |||||
| text = res.pop(self._text_key) | |||||
| docs.append(Document(page_content=text, metadata=res)) | |||||
| return docs | |||||
| def max_marginal_relevance_search( | |||||
| self, | |||||
| query: str, | |||||
| k: int = 4, | |||||
| fetch_k: int = 20, | |||||
| lambda_mult: float = 0.5, | |||||
| **kwargs: Any, | |||||
| ) -> list[Document]: | |||||
| """Return docs selected using the maximal marginal relevance. | |||||
| Maximal marginal relevance optimizes for similarity to query AND diversity | |||||
| among selected documents. | |||||
| Args: | |||||
| query: Text to look up documents similar to. | |||||
| k: Number of Documents to return. Defaults to 4. | |||||
| fetch_k: Number of Documents to fetch to pass to MMR algorithm. | |||||
| lambda_mult: Number between 0 and 1 that determines the degree | |||||
| of diversity among the results with 0 corresponding | |||||
| to maximum diversity and 1 to minimum diversity. | |||||
| Defaults to 0.5. | |||||
| Returns: | |||||
| List of Documents selected by maximal marginal relevance. | |||||
| """ | |||||
| if self._embedding is not None: | |||||
| embedding = self._embedding.embed_query(query) | |||||
| else: | |||||
| raise ValueError( | |||||
| "max_marginal_relevance_search requires a suitable Embeddings object" | |||||
| ) | |||||
| return self.max_marginal_relevance_search_by_vector( | |||||
| embedding, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, **kwargs | |||||
| ) | |||||
| def max_marginal_relevance_search_by_vector( | |||||
| self, | |||||
| embedding: list[float], | |||||
| k: int = 4, | |||||
| fetch_k: int = 20, | |||||
| lambda_mult: float = 0.5, | |||||
| **kwargs: Any, | |||||
| ) -> list[Document]: | |||||
| """Return docs selected using the maximal marginal relevance. | |||||
| Maximal marginal relevance optimizes for similarity to query AND diversity | |||||
| among selected documents. | |||||
| Args: | |||||
| embedding: Embedding to look up documents similar to. | |||||
| k: Number of Documents to return. Defaults to 4. | |||||
| fetch_k: Number of Documents to fetch to pass to MMR algorithm. | |||||
| lambda_mult: Number between 0 and 1 that determines the degree | |||||
| of diversity among the results with 0 corresponding | |||||
| to maximum diversity and 1 to minimum diversity. | |||||
| Defaults to 0.5. | |||||
| Returns: | |||||
| List of Documents selected by maximal marginal relevance. | |||||
| """ | |||||
| vector = {"vector": embedding} | |||||
| query_obj = self._client.query.get(self._index_name, self._query_attrs) | |||||
| if kwargs.get("where_filter"): | |||||
| query_obj = query_obj.with_where(kwargs.get("where_filter")) | |||||
| results = ( | |||||
| query_obj.with_additional("vector") | |||||
| .with_near_vector(vector) | |||||
| .with_limit(fetch_k) | |||||
| .do() | |||||
| ) | |||||
| payload = results["data"]["Get"][self._index_name] | |||||
| embeddings = [result["_additional"]["vector"] for result in payload] | |||||
| mmr_selected = maximal_marginal_relevance( | |||||
| np.array(embedding), embeddings, k=k, lambda_mult=lambda_mult | |||||
| ) | |||||
| docs = [] | |||||
| for idx in mmr_selected: | |||||
| text = payload[idx].pop(self._text_key) | |||||
| payload[idx].pop("_additional") | |||||
| meta = payload[idx] | |||||
| docs.append(Document(page_content=text, metadata=meta)) | |||||
| return docs | |||||
| def similarity_search_with_score( | |||||
| self, query: str, k: int = 4, **kwargs: Any | |||||
| ) -> list[tuple[Document, float]]: | |||||
| """ | |||||
| Return the list of documents most similar to the query | |||||
| text, along with the cosine distance for each. | |||||
| A lower score indicates higher similarity. | |||||
| """ | |||||
| if self._embedding is None: | |||||
| raise ValueError( | |||||
| "_embedding cannot be None for similarity_search_with_score" | |||||
| ) | |||||
| content: dict[str, Any] = {"concepts": [query]} | |||||
| if kwargs.get("search_distance"): | |||||
| content["certainty"] = kwargs.get("search_distance") | |||||
| query_obj = self._client.query.get(self._index_name, self._query_attrs) | |||||
| embedded_query = self._embedding.embed_query(query) | |||||
| if not self._by_text: | |||||
| vector = {"vector": embedded_query} | |||||
| result = ( | |||||
| query_obj.with_near_vector(vector) | |||||
| .with_limit(k) | |||||
| .with_additional(["vector", "distance"]) | |||||
| .do() | |||||
| ) | |||||
| else: | |||||
| result = ( | |||||
| query_obj.with_near_text(content) | |||||
| .with_limit(k) | |||||
| .with_additional(["vector", "distance"]) | |||||
| .do() | |||||
| ) | |||||
| if "errors" in result: | |||||
| raise ValueError(f"Error during query: {result['errors']}") | |||||
| docs_and_scores = [] | |||||
| for res in result["data"]["Get"][self._index_name]: | |||||
| text = res.pop(self._text_key) | |||||
| score = res["_additional"]["distance"] | |||||
| docs_and_scores.append((Document(page_content=text, metadata=res), score)) | |||||
| return docs_and_scores | |||||
| @classmethod | |||||
| def from_texts( | |||||
| cls: type[Weaviate], | |||||
| texts: list[str], | |||||
| embedding: Embeddings, | |||||
| metadatas: Optional[list[dict]] = None, | |||||
| **kwargs: Any, | |||||
| ) -> Weaviate: | |||||
| """Construct Weaviate wrapper from raw documents. | |||||
| This is a user-friendly interface that: | |||||
| 1. Embeds documents. | |||||
| 2. Creates a new index for the embeddings in the Weaviate instance. | |||||
| 3. Adds the documents to the newly created Weaviate index. | |||||
| This is intended to be a quick way to get started. | |||||
| Example: | |||||
| .. code-block:: python | |||||
| from langchain.vectorstores.weaviate import Weaviate | |||||
| from langchain.embeddings import OpenAIEmbeddings | |||||
| embeddings = OpenAIEmbeddings() | |||||
| weaviate = Weaviate.from_texts( | |||||
| texts, | |||||
| embeddings, | |||||
| weaviate_url="http://localhost:8080" | |||||
| ) | |||||
| """ | |||||
| client = _create_weaviate_client(**kwargs) | |||||
| from weaviate.util import get_valid_uuid | |||||
| index_name = kwargs.get("index_name", f"LangChain_{uuid4().hex}") | |||||
| embeddings = embedding.embed_documents(texts) if embedding else None | |||||
| text_key = "text" | |||||
| schema = _default_schema(index_name) | |||||
| attributes = list(metadatas[0].keys()) if metadatas else None | |||||
| # check whether the index already exists | |||||
| if not client.schema.contains(schema): | |||||
| client.schema.create_class(schema) | |||||
| with client.batch as batch: | |||||
| for i, text in enumerate(texts): | |||||
| data_properties = { | |||||
| text_key: text, | |||||
| } | |||||
| if metadatas is not None: | |||||
| for key in metadatas[i].keys(): | |||||
| data_properties[key] = metadatas[i][key] | |||||
| # If the UUID of one of the objects already exists | |||||
| # then the existing object will be replaced by the new object. | |||||
| if "uuids" in kwargs: | |||||
| _id = kwargs["uuids"][i] | |||||
| else: | |||||
| _id = get_valid_uuid(uuid4()) | |||||
| # if an embedding strategy is not provided, we let | |||||
| # weaviate create the embedding. Note that this will only | |||||
| # work if weaviate has been installed with a vectorizer module | |||||
| # such as text2vec-contextionary. | |||||
| params = { | |||||
| "uuid": _id, | |||||
| "data_object": data_properties, | |||||
| "class_name": index_name, | |||||
| } | |||||
| if embeddings is not None: | |||||
| params["vector"] = embeddings[i] | |||||
| batch.add_data_object(**params) | |||||
| batch.flush() | |||||
| relevance_score_fn = kwargs.get("relevance_score_fn") | |||||
| by_text: bool = kwargs.get("by_text", False) | |||||
| return cls( | |||||
| client, | |||||
| index_name, | |||||
| text_key, | |||||
| embedding=embedding, | |||||
| attributes=attributes, | |||||
| relevance_score_fn=relevance_score_fn, | |||||
| by_text=by_text, | |||||
| ) | |||||
| def delete(self, ids: Optional[list[str]] = None, **kwargs: Any) -> None: | |||||
| """Delete by vector IDs. | |||||
| Args: | |||||
| ids: List of ids to delete. | |||||
| """ | |||||
| if ids is None: | |||||
| raise ValueError("No ids provided to delete.") | |||||
| # TODO: Check if this can be done in bulk | |||||
| for id in ids: | |||||
| self._client.data_object.delete(uuid=id) |
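A short sketch of the ID-based deletion, continuing the `store` wrapper from the sketch above; the UUIDs are placeholders (for example, values originally supplied through the "uuids" kwarg of from_texts).

    # Sketch: delete two objects by their Weaviate UUIDs, one request per ID.
    store.delete(ids=[
        "00000000-0000-0000-0000-000000000001",   # placeholder UUIDs
        "00000000-0000-0000-0000-000000000002",
    ])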
| from core.vector_store.vector.weaviate import Weaviate | |||||
| class WeaviateVectorStore(Weaviate): | |||||
| def del_texts(self, where_filter: dict): | |||||
| if not where_filter: | |||||
| raise ValueError('where_filter must not be empty') | |||||
| self._client.batch.delete_objects( | |||||
| class_name=self._index_name, | |||||
| where=where_filter, | |||||
| output='minimal' | |||||
| ) | |||||
| def del_text(self, uuid: str) -> None: | |||||
| self._client.data_object.delete( | |||||
| uuid, | |||||
| class_name=self._index_name | |||||
| ) | |||||
| def text_exists(self, uuid: str) -> bool: | |||||
| result = self._client.query.get(self._index_name).with_additional(["id"]).with_where({ | |||||
| "path": ["doc_id"], | |||||
| "operator": "Equal", | |||||
| "valueText": uuid, | |||||
| }).with_limit(1).do() | |||||
| if "errors" in result: | |||||
| raise ValueError(f"Error during query: {result['errors']}") | |||||
| entries = result["data"]["Get"][self._index_name] | |||||
| if len(entries) == 0: | |||||
| return False | |||||
| return True | |||||
| def delete(self): | |||||
| self._client.schema.delete_class(self._index_name) |
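To illustrate how the subclass above is driven, the following sketch removes every chunk of one document with a Weaviate where-filter and then probes a single chunk. The filter shape follows the one used by text_exists; the index name and IDs are placeholders.

    # Sketch (assumed: the same `client` as in the earlier examples and an index
    # whose objects carry `document_id` / `doc_id` metadata fields).
    store = WeaviateVectorStore(client, "Vector_index_example", "text")   # hypothetical index name

    # Delete all objects whose `document_id` metadata equals one document's ID.
    store.del_texts({
        "path": ["document_id"],
        "operator": "Equal",
        "valueText": "example-document-id",      # placeholder
    })

    # text_exists() looks a chunk up by its `doc_id` metadata field.
    print(store.text_exists("example-index-node-id"))   # expected False after the delete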
| def handle(sender, **kwargs): | def handle(sender, **kwargs): | ||||
| dataset = sender | dataset = sender | ||||
| clean_dataset_task.delay(dataset.id, dataset.tenant_id, dataset.indexing_technique, | clean_dataset_task.delay(dataset.id, dataset.tenant_id, dataset.indexing_technique, | ||||
| dataset.index_struct, dataset.collection_binding_id) | |||||
| dataset.index_struct, dataset.collection_binding_id, dataset.doc_form) |
| def handle(sender, **kwargs): | def handle(sender, **kwargs): | ||||
| document_id = sender | document_id = sender | ||||
| dataset_id = kwargs.get('dataset_id') | dataset_id = kwargs.get('dataset_id') | ||||
| clean_document_task.delay(document_id, dataset_id) | |||||
| doc_form = kwargs.get('doc_form') | |||||
| clean_document_task.delay(document_id, dataset_id, doc_form) |
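For context, a sketch of the sender side, which now has to pass the extra doc_form keyword; it mirrors the DatasetService.delete_document change later in this diff, with `document` assumed to be an already-loaded Document model.

    # Sketch: emit the signal with the new doc_form kwarg so the handler above
    # can forward it to clean_document_task.
    from events.document_event import document_was_deleted

    document_was_deleted.send(
        document.id,                     # `document` is a loaded Document model (assumed)
        dataset_id=document.dataset_id,
        doc_form=document.doc_form,
    )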
| return Document.query.with_entities(func.coalesce(func.sum(Document.word_count))) \ | return Document.query.with_entities(func.coalesce(func.sum(Document.word_count))) \ | ||||
| .filter(Document.dataset_id == self.id).scalar() | .filter(Document.dataset_id == self.id).scalar() | ||||
| @property | |||||
| def doc_form(self): | |||||
| document = db.session.query(Document).filter( | |||||
| Document.dataset_id == self.id).first() | |||||
| if document: | |||||
| return document.doc_form | |||||
| return None | |||||
| @property | @property | ||||
| def retrieval_model_dict(self): | def retrieval_model_dict(self): | ||||
| default_retrieval_model = { | default_retrieval_model = { |
| from werkzeug.exceptions import NotFound | from werkzeug.exceptions import NotFound | ||||
| import app | import app | ||||
| from core.index.index import IndexBuilder | |||||
| from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | |||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from models.dataset import Dataset, DatasetQuery, Document | from models.dataset import Dataset, DatasetQuery, Document | ||||
| if not documents or len(documents) == 0: | if not documents or len(documents) == 0: | ||||
| try: | try: | ||||
| # remove index | # remove index | ||||
| vector_index = IndexBuilder.get_index(dataset, 'high_quality') | |||||
| kw_index = IndexBuilder.get_index(dataset, 'economy') | |||||
| # delete from vector index | |||||
| if vector_index: | |||||
| if dataset.collection_binding_id: | |||||
| vector_index.delete_by_group_id(dataset.id) | |||||
| else: | |||||
| if dataset.collection_binding_id: | |||||
| vector_index.delete_by_group_id(dataset.id) | |||||
| else: | |||||
| vector_index.delete() | |||||
| kw_index.delete() | |||||
| index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor() | |||||
| index_processor.clean(dataset, None) | |||||
| # update document | # update document | ||||
| update_params = { | update_params = { | ||||
| Document.enabled: False | Document.enabled: False |
| from sqlalchemy import func | from sqlalchemy import func | ||||
| from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError | from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError | ||||
| from core.index.index import IndexBuilder | |||||
| from core.model_manager import ModelManager | from core.model_manager import ModelManager | ||||
| from core.model_runtime.entities.model_entities import ModelType | from core.model_runtime.entities.model_entities import ModelType | ||||
| from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel | from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel | ||||
| from core.rag.datasource.keyword.keyword_factory import Keyword | |||||
| from core.rag.models.document import Document as RAGDocument | |||||
| from events.dataset_event import dataset_was_deleted | from events.dataset_event import dataset_was_deleted | ||||
| from events.document_event import document_was_deleted | from events.document_event import document_was_deleted | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| @staticmethod | @staticmethod | ||||
| def delete_document(document): | def delete_document(document): | ||||
| # trigger document_was_deleted signal | # trigger document_was_deleted signal | ||||
| document_was_deleted.send(document.id, dataset_id=document.dataset_id) | |||||
| document_was_deleted.send(document.id, dataset_id=document.dataset_id, doc_form=document.doc_form) | |||||
| db.session.delete(document) | db.session.delete(document) | ||||
| db.session.commit() | db.session.commit() | ||||
| # save vector index | # save vector index | ||||
| try: | try: | ||||
| VectorService.create_segment_vector(args['keywords'], segment_document, dataset) | |||||
| VectorService.create_segments_vector([args['keywords']], [segment_document], dataset) | |||||
| except Exception as e: | except Exception as e: | ||||
| logging.exception("create segment index failed") | logging.exception("create segment index failed") | ||||
| segment_document.enabled = False | segment_document.enabled = False | ||||
| ).scalar() | ).scalar() | ||||
| pre_segment_data_list = [] | pre_segment_data_list = [] | ||||
| segment_data_list = [] | segment_data_list = [] | ||||
| keywords_list = [] | |||||
| for segment_item in segments: | for segment_item in segments: | ||||
| content = segment_item['content'] | content = segment_item['content'] | ||||
| doc_id = str(uuid.uuid4()) | doc_id = str(uuid.uuid4()) | ||||
| segment_document.answer = segment_item['answer'] | segment_document.answer = segment_item['answer'] | ||||
| db.session.add(segment_document) | db.session.add(segment_document) | ||||
| segment_data_list.append(segment_document) | segment_data_list.append(segment_document) | ||||
| pre_segment_data = { | |||||
| 'segment': segment_document, | |||||
| 'keywords': segment_item['keywords'] | |||||
| } | |||||
| pre_segment_data_list.append(pre_segment_data) | |||||
| pre_segment_data_list.append(segment_document) | |||||
| keywords_list.append(segment_item['keywords']) | |||||
| try: | try: | ||||
| # save vector index | # save vector index | ||||
| VectorService.multi_create_segment_vector(pre_segment_data_list, dataset) | |||||
| VectorService.create_segments_vector(keywords_list, pre_segment_data_list, dataset) | |||||
| except Exception as e: | except Exception as e: | ||||
| logging.exception("create segment index failed") | logging.exception("create segment index failed") | ||||
| for segment_document in segment_data_list: | for segment_document in segment_data_list: | ||||
| db.session.commit() | db.session.commit() | ||||
| # update segment index task | # update segment index task | ||||
| if args['keywords']: | if args['keywords']: | ||||
| kw_index = IndexBuilder.get_index(dataset, 'economy') | |||||
| # delete from keyword index | |||||
| kw_index.delete_by_ids([segment.index_node_id]) | |||||
| # save keyword index | |||||
| kw_index.update_segment_keywords_index(segment.index_node_id, segment.keywords) | |||||
| keyword = Keyword(dataset) | |||||
| keyword.delete_by_ids([segment.index_node_id]) | |||||
| document = RAGDocument( | |||||
| page_content=segment.content, | |||||
| metadata={ | |||||
| "doc_id": segment.index_node_id, | |||||
| "doc_hash": segment.index_node_hash, | |||||
| "document_id": segment.document_id, | |||||
| "dataset_id": segment.dataset_id, | |||||
| } | |||||
| ) | |||||
| keyword.add_texts([document], keywords_list=[args['keywords']]) | |||||
| else: | else: | ||||
| segment_hash = helper.generate_text_hash(content) | segment_hash = helper.generate_text_hash(content) | ||||
| tokens = 0 | tokens = 0 |
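The keyword branch above is a delete-then-reinsert pattern against the new Keyword facade. A condensed sketch of that pattern in isolation, with `dataset` and `segment` assumed to be loaded ORM objects and the keyword list purely illustrative:

    # Sketch: replace the keyword-index entry for a single segment.
    from core.rag.datasource.keyword.keyword_factory import Keyword
    from core.rag.models.document import Document as RAGDocument

    keyword = Keyword(dataset)
    keyword.delete_by_ids([segment.index_node_id])    # drop the stale entry first
    keyword.add_texts(
        [RAGDocument(
            page_content=segment.content,
            metadata={
                "doc_id": segment.index_node_id,
                "doc_hash": segment.index_node_hash,
                "document_id": segment.document_id,
                "dataset_id": segment.dataset_id,
            },
        )],
        keywords_list=[["vector", "keyword", "index"]],   # illustrative keywords
    )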
| from werkzeug.datastructures import FileStorage | from werkzeug.datastructures import FileStorage | ||||
| from werkzeug.exceptions import NotFound | from werkzeug.exceptions import NotFound | ||||
| from core.data_loader.file_extractor import FileExtractor | |||||
| from core.file.upload_file_parser import UploadFileParser | from core.file.upload_file_parser import UploadFileParser | ||||
| from core.rag.extractor.extract_processor import ExtractProcessor | |||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from extensions.ext_storage import storage | from extensions.ext_storage import storage | ||||
| from models.account import Account | from models.account import Account | ||||
| def upload_file(file: FileStorage, user: Union[Account, EndUser], only_image: bool = False) -> UploadFile: | def upload_file(file: FileStorage, user: Union[Account, EndUser], only_image: bool = False) -> UploadFile: | ||||
| extension = file.filename.split('.')[-1] | extension = file.filename.split('.')[-1] | ||||
| etl_type = current_app.config['ETL_TYPE'] | etl_type = current_app.config['ETL_TYPE'] | ||||
| allowed_extensions = UNSTRUSTURED_ALLOWED_EXTENSIONS if etl_type == 'Unstructured' else ALLOWED_EXTENSIONS | |||||
| allowed_extensions = UNSTRUSTURED_ALLOWED_EXTENSIONS + IMAGE_EXTENSIONS if etl_type == 'Unstructured' \ | |||||
| else ALLOWED_EXTENSIONS + IMAGE_EXTENSIONS | |||||
| if extension.lower() not in allowed_extensions: | if extension.lower() not in allowed_extensions: | ||||
| raise UnsupportedFileTypeError() | raise UnsupportedFileTypeError() | ||||
| elif only_image and extension.lower() not in IMAGE_EXTENSIONS: | elif only_image and extension.lower() not in IMAGE_EXTENSIONS: | ||||
| if extension.lower() not in allowed_extensions: | if extension.lower() not in allowed_extensions: | ||||
| raise UnsupportedFileTypeError() | raise UnsupportedFileTypeError() | ||||
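The whitelist above now always admits image types on top of the ETL-specific document types. A standalone sketch of that selection logic; the extension lists here are illustrative subsets, not the real module constants:

    # Sketch of the extension whitelist selection (illustrative subsets only).
    ALLOWED_EXTENSIONS = ['txt', 'markdown', 'pdf', 'html', 'docx', 'csv']
    UNSTRUSTURED_ALLOWED_EXTENSIONS = ALLOWED_EXTENSIONS + ['ppt', 'pptx', 'eml', 'msg']
    IMAGE_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif', 'svg']

    def allowed_extensions_for(etl_type: str) -> list[str]:
        base = UNSTRUSTURED_ALLOWED_EXTENSIONS if etl_type == 'Unstructured' else ALLOWED_EXTENSIONS
        return base + IMAGE_EXTENSIONS

    print('png' in allowed_extensions_for('dify'))   # True: images are accepted either way
    print('ppt' in allowed_extensions_for('dify'))   # False: only the Unstructured ETL handles it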
| text = FileExtractor.load(upload_file, return_text=True) | |||||
| text = ExtractProcessor.load_from_upload_file(upload_file, return_text=True) | |||||
| text = text[0:PREVIEW_WORDS_LIMIT] if text else '' | text = text[0:PREVIEW_WORDS_LIMIT] if text else '' | ||||
| return text | return text | ||||
| return generator, upload_file.mime_type | return generator, upload_file.mime_type | ||||
| @staticmethod | @staticmethod | ||||
| def get_public_image_preview(file_id: str) -> str: | |||||
| def get_public_image_preview(file_id: str) -> tuple[Generator, str]: | |||||
| upload_file = db.session.query(UploadFile) \ | upload_file = db.session.query(UploadFile) \ | ||||
| .filter(UploadFile.id == file_id) \ | .filter(UploadFile.id == file_id) \ | ||||
| .first() | .first() |
| import logging | import logging | ||||
| import threading | |||||
| import time | import time | ||||
| import numpy as np | import numpy as np | ||||
| from flask import current_app | |||||
| from langchain.embeddings.base import Embeddings | |||||
| from langchain.schema import Document | |||||
| from sklearn.manifold import TSNE | from sklearn.manifold import TSNE | ||||
| from core.embedding.cached_embedding import CacheEmbedding | from core.embedding.cached_embedding import CacheEmbedding | ||||
| from core.model_manager import ModelManager | from core.model_manager import ModelManager | ||||
| from core.model_runtime.entities.model_entities import ModelType | from core.model_runtime.entities.model_entities import ModelType | ||||
| from core.rerank.rerank import RerankRunner | |||||
| from core.rag.datasource.entity.embedding import Embeddings | |||||
| from core.rag.datasource.retrieval_service import RetrievalService | |||||
| from core.rag.models.document import Document | |||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from models.account import Account | from models.account import Account | ||||
| from models.dataset import Dataset, DatasetQuery, DocumentSegment | from models.dataset import Dataset, DatasetQuery, DocumentSegment | ||||
| from services.retrieval_service import RetrievalService | |||||
| default_retrieval_model = { | default_retrieval_model = { | ||||
| 'search_method': 'semantic_search', | 'search_method': 'semantic_search', | ||||
| 'score_threshold_enabled': False | 'score_threshold_enabled': False | ||||
| } | } | ||||
| class HitTestingService: | class HitTestingService: | ||||
| @classmethod | @classmethod | ||||
| def retrieve(cls, dataset: Dataset, query: str, account: Account, retrieval_model: dict, limit: int = 10) -> dict: | def retrieve(cls, dataset: Dataset, query: str, account: Account, retrieval_model: dict, limit: int = 10) -> dict: | ||||
| embeddings = CacheEmbedding(embedding_model) | embeddings = CacheEmbedding(embedding_model) | ||||
| all_documents = [] | |||||
| threads = [] | |||||
| # retrieval_model source with semantic | |||||
| if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search': | |||||
| embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={ | |||||
| 'flask_app': current_app._get_current_object(), | |||||
| 'dataset_id': str(dataset.id), | |||||
| 'query': query, | |||||
| 'top_k': retrieval_model['top_k'], | |||||
| 'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enabled'] else None, | |||||
| 'reranking_model': retrieval_model['reranking_model'] if retrieval_model['reranking_enable'] else None, | |||||
| 'all_documents': all_documents, | |||||
| 'search_method': retrieval_model['search_method'], | |||||
| 'embeddings': embeddings | |||||
| }) | |||||
| threads.append(embedding_thread) | |||||
| embedding_thread.start() | |||||
| # retrieval source with full text | |||||
| if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search': | |||||
| full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={ | |||||
| 'flask_app': current_app._get_current_object(), | |||||
| 'dataset_id': str(dataset.id), | |||||
| 'query': query, | |||||
| 'search_method': retrieval_model['search_method'], | |||||
| 'embeddings': embeddings, | |||||
| 'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enabled'] else None, | |||||
| 'top_k': retrieval_model['top_k'], | |||||
| 'reranking_model': retrieval_model['reranking_model'] if retrieval_model['reranking_enable'] else None, | |||||
| 'all_documents': all_documents | |||||
| }) | |||||
| threads.append(full_text_index_thread) | |||||
| full_text_index_thread.start() | |||||
| for thread in threads: | |||||
| thread.join() | |||||
| if retrieval_model['search_method'] == 'hybrid_search': | |||||
| model_manager = ModelManager() | |||||
| rerank_model_instance = model_manager.get_model_instance( | |||||
| tenant_id=dataset.tenant_id, | |||||
| provider=retrieval_model['reranking_model']['reranking_provider_name'], | |||||
| model_type=ModelType.RERANK, | |||||
| model=retrieval_model['reranking_model']['reranking_model_name'] | |||||
| ) | |||||
| rerank_runner = RerankRunner(rerank_model_instance) | |||||
| all_documents = rerank_runner.run( | |||||
| query=query, | |||||
| documents=all_documents, | |||||
| score_threshold=retrieval_model['score_threshold'] if retrieval_model['score_threshold_enabled'] else None, | |||||
| top_n=retrieval_model['top_k'], | |||||
| user=f"account-{account.id}" | |||||
| ) | |||||
| all_documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'], | |||||
| dataset_id=dataset.id, | |||||
| query=query, | |||||
| top_k=retrieval_model['top_k'], | |||||
| score_threshold=retrieval_model['score_threshold'] | |||||
| if retrieval_model['score_threshold_enabled'] else None, | |||||
| reranking_model=retrieval_model['reranking_model'] | |||||
| if retrieval_model['reranking_enable'] else None | |||||
| ) | |||||
| end = time.perf_counter() | end = time.perf_counter() | ||||
| logging.debug(f"Hit testing retrieve in {end - start:0.4f} seconds") | logging.debug(f"Hit testing retrieve in {end - start:0.4f} seconds") | ||||
| if not query or len(query) > 250: | if not query or len(query) > 250: | ||||
| raise ValueError('Query is required and cannot exceed 250 characters') | raise ValueError('Query is required and cannot exceed 250 characters') | ||||
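The single RetrievalService.retrieve call above replaces the hand-rolled threading and reranking that the removed retrieval service module shown next used to manage. A hedged sketch of calling it directly with an explicit retrieval_model dict; the reranking provider/model names and thresholds are placeholders:

    # Sketch: drive the unified retrieval entry point directly.
    # `dataset` is assumed to be a loaded Dataset model.
    from core.rag.datasource.retrieval_service import RetrievalService

    retrieval_model = {
        'search_method': 'hybrid_search',
        'reranking_enable': True,
        'reranking_model': {
            'reranking_provider_name': 'cohere',             # placeholder provider
            'reranking_model_name': 'rerank-english-v2.0',   # placeholder model
        },
        'top_k': 4,
        'score_threshold_enabled': True,
        'score_threshold': 0.5,
    }

    documents = RetrievalService.retrieve(
        retrival_method=retrieval_model['search_method'],    # parameter name as spelled in the service
        dataset_id=dataset.id,
        query='what is a vector database?',
        top_k=retrieval_model['top_k'],
        score_threshold=retrieval_model['score_threshold']
        if retrieval_model['score_threshold_enabled'] else None,
        reranking_model=retrieval_model['reranking_model']
        if retrieval_model['reranking_enable'] else None,
    )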
| from typing import Optional | |||||
| from flask import Flask, current_app | |||||
| from langchain.embeddings.base import Embeddings | |||||
| from core.index.vector_index.vector_index import VectorIndex | |||||
| from core.model_manager import ModelManager | |||||
| from core.model_runtime.entities.model_entities import ModelType | |||||
| from core.model_runtime.errors.invoke import InvokeAuthorizationError | |||||
| from core.rerank.rerank import RerankRunner | |||||
| from extensions.ext_database import db | |||||
| from models.dataset import Dataset | |||||
| default_retrieval_model = { | |||||
| 'search_method': 'semantic_search', | |||||
| 'reranking_enable': False, | |||||
| 'reranking_model': { | |||||
| 'reranking_provider_name': '', | |||||
| 'reranking_model_name': '' | |||||
| }, | |||||
| 'top_k': 2, | |||||
| 'score_threshold_enabled': False | |||||
| } | |||||
| class RetrievalService: | |||||
| @classmethod | |||||
| def embedding_search(cls, flask_app: Flask, dataset_id: str, query: str, | |||||
| top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict], | |||||
| all_documents: list, search_method: str, embeddings: Embeddings): | |||||
| with flask_app.app_context(): | |||||
| dataset = db.session.query(Dataset).filter( | |||||
| Dataset.id == dataset_id | |||||
| ).first() | |||||
| vector_index = VectorIndex( | |||||
| dataset=dataset, | |||||
| config=current_app.config, | |||||
| embeddings=embeddings | |||||
| ) | |||||
| documents = vector_index.search( | |||||
| query, | |||||
| search_type='similarity_score_threshold', | |||||
| search_kwargs={ | |||||
| 'k': top_k, | |||||
| 'score_threshold': score_threshold, | |||||
| 'filter': { | |||||
| 'group_id': [dataset.id] | |||||
| } | |||||
| } | |||||
| ) | |||||
| if documents: | |||||
| if reranking_model and search_method == 'semantic_search': | |||||
| try: | |||||
| model_manager = ModelManager() | |||||
| rerank_model_instance = model_manager.get_model_instance( | |||||
| tenant_id=dataset.tenant_id, | |||||
| provider=reranking_model['reranking_provider_name'], | |||||
| model_type=ModelType.RERANK, | |||||
| model=reranking_model['reranking_model_name'] | |||||
| ) | |||||
| except InvokeAuthorizationError: | |||||
| return | |||||
| rerank_runner = RerankRunner(rerank_model_instance) | |||||
| all_documents.extend(rerank_runner.run( | |||||
| query=query, | |||||
| documents=documents, | |||||
| score_threshold=score_threshold, | |||||
| top_n=len(documents) | |||||
| )) | |||||
| else: | |||||
| all_documents.extend(documents) | |||||
| @classmethod | |||||
| def full_text_index_search(cls, flask_app: Flask, dataset_id: str, query: str, | |||||
| top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict], | |||||
| all_documents: list, search_method: str, embeddings: Embeddings): | |||||
| with flask_app.app_context(): | |||||
| dataset = db.session.query(Dataset).filter( | |||||
| Dataset.id == dataset_id | |||||
| ).first() | |||||
| vector_index = VectorIndex( | |||||
| dataset=dataset, | |||||
| config=current_app.config, | |||||
| embeddings=embeddings | |||||
| ) | |||||
| documents = vector_index.search_by_full_text_index( | |||||
| query, | |||||
| search_type='similarity_score_threshold', | |||||
| top_k=top_k | |||||
| ) | |||||
| if documents: | |||||
| if reranking_model and search_method == 'full_text_search': | |||||
| try: | |||||
| model_manager = ModelManager() | |||||
| rerank_model_instance = model_manager.get_model_instance( | |||||
| tenant_id=dataset.tenant_id, | |||||
| provider=reranking_model['reranking_provider_name'], | |||||
| model_type=ModelType.RERANK, | |||||
| model=reranking_model['reranking_model_name'] | |||||
| ) | |||||
| except InvokeAuthorizationError: | |||||
| return | |||||
| rerank_runner = RerankRunner(rerank_model_instance) | |||||
| all_documents.extend(rerank_runner.run( | |||||
| query=query, | |||||
| documents=documents, | |||||
| score_threshold=score_threshold, | |||||
| top_n=len(documents) | |||||
| )) | |||||
| else: | |||||
| all_documents.extend(documents) |
| from typing import Optional | from typing import Optional | ||||
| from langchain.schema import Document | |||||
| from core.index.index import IndexBuilder | |||||
| from core.rag.datasource.keyword.keyword_factory import Keyword | |||||
| from core.rag.datasource.vdb.vector_factory import Vector | |||||
| from core.rag.models.document import Document | |||||
| from models.dataset import Dataset, DocumentSegment | from models.dataset import Dataset, DocumentSegment | ||||
| class VectorService: | class VectorService: | ||||
| @classmethod | @classmethod | ||||
| def create_segment_vector(cls, keywords: Optional[list[str]], segment: DocumentSegment, dataset: Dataset): | |||||
| document = Document( | |||||
| page_content=segment.content, | |||||
| metadata={ | |||||
| "doc_id": segment.index_node_id, | |||||
| "doc_hash": segment.index_node_hash, | |||||
| "document_id": segment.document_id, | |||||
| "dataset_id": segment.dataset_id, | |||||
| } | |||||
| ) | |||||
| # save vector index | |||||
| index = IndexBuilder.get_index(dataset, 'high_quality') | |||||
| if index: | |||||
| index.add_texts([document], duplicate_check=True) | |||||
| # save keyword index | |||||
| index = IndexBuilder.get_index(dataset, 'economy') | |||||
| if index: | |||||
| if keywords and len(keywords) > 0: | |||||
| index.create_segment_keywords(segment.index_node_id, keywords) | |||||
| else: | |||||
| index.add_texts([document]) | |||||
| @classmethod | |||||
| def multi_create_segment_vector(cls, pre_segment_data_list: list, dataset: Dataset): | |||||
| def create_segments_vector(cls, keywords_list: Optional[list[list[str]]], | |||||
| segments: list[DocumentSegment], dataset: Dataset): | |||||
| documents = [] | documents = [] | ||||
| for pre_segment_data in pre_segment_data_list: | |||||
| segment = pre_segment_data['segment'] | |||||
| for segment in segments: | |||||
| document = Document( | document = Document( | ||||
| page_content=segment.content, | page_content=segment.content, | ||||
| metadata={ | metadata={ | ||||
| } | } | ||||
| ) | ) | ||||
| documents.append(document) | documents.append(document) | ||||
| # save vector index | |||||
| index = IndexBuilder.get_index(dataset, 'high_quality') | |||||
| if index: | |||||
| index.add_texts(documents, duplicate_check=True) | |||||
| if dataset.indexing_technique == 'high_quality': | |||||
| # save vector index | |||||
| vector = Vector( | |||||
| dataset=dataset | |||||
| ) | |||||
| vector.add_texts(documents, duplicate_check=True) | |||||
| # save keyword index | # save keyword index | ||||
| keyword_index = IndexBuilder.get_index(dataset, 'economy') | |||||
| if keyword_index: | |||||
| keyword_index.multi_create_segment_keywords(pre_segment_data_list) | |||||
| keyword = Keyword(dataset) | |||||
| if keywords_list and len(keywords_list) > 0: | |||||
| keyword.add_texts(documents, keywords_list=keywords_list) | |||||
| else: | |||||
| keyword.add_texts(documents) | |||||
| @classmethod | @classmethod | ||||
| def update_segment_vector(cls, keywords: Optional[list[str]], segment: DocumentSegment, dataset: Dataset): | def update_segment_vector(cls, keywords: Optional[list[str]], segment: DocumentSegment, dataset: Dataset): | ||||
| # update segment index task | # update segment index task | ||||
| vector_index = IndexBuilder.get_index(dataset, 'high_quality') | |||||
| kw_index = IndexBuilder.get_index(dataset, 'economy') | |||||
| # delete from vector index | |||||
| if vector_index: | |||||
| vector_index.delete_by_ids([segment.index_node_id]) | |||||
| # delete from keyword index | |||||
| kw_index.delete_by_ids([segment.index_node_id]) | |||||
| # add new index | |||||
| # format new index | |||||
| document = Document( | document = Document( | ||||
| page_content=segment.content, | page_content=segment.content, | ||||
| metadata={ | metadata={ | ||||
| "dataset_id": segment.dataset_id, | "dataset_id": segment.dataset_id, | ||||
| } | } | ||||
| ) | ) | ||||
| if dataset.indexing_technique == 'high_quality': | |||||
| # update vector index | |||||
| vector = Vector( | |||||
| dataset=dataset | |||||
| ) | |||||
| vector.delete_by_ids([segment.index_node_id]) | |||||
| vector.add_texts([document], duplicate_check=True) | |||||
| # save vector index | |||||
| if vector_index: | |||||
| vector_index.add_texts([document], duplicate_check=True) | |||||
| # update keyword index | |||||
| keyword = Keyword(dataset) | |||||
| keyword.delete_by_ids([segment.index_node_id]) | |||||
| # save keyword index | # save keyword index | ||||
| if keywords and len(keywords) > 0: | if keywords and len(keywords) > 0: | ||||
| kw_index.create_segment_keywords(segment.index_node_id, keywords) | |||||
| keyword.add_texts([document], keywords_list=[keywords]) | |||||
| else: | else: | ||||
| kw_index.add_texts([document]) | |||||
| keyword.add_texts([document]) |
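Putting the new facade to work, a sketch of how a caller batches freshly created segments and later refreshes one of them; the import path for VectorService and the keyword lists are assumptions:

    # Sketch: index a batch of new segments, then update a single one later.
    # `dataset` and `segment_documents` (a list of DocumentSegment ORM objects)
    # are assumed to be prepared as in the segment service changes above.
    from services.vector_service import VectorService   # assumed module path

    keywords_list = [
        ['retrieval', 'vector'],      # illustrative keywords for segment 0
        ['keyword', 'index'],         # illustrative keywords for segment 1
    ]
    VectorService.create_segments_vector(keywords_list, segment_documents, dataset)

    # Re-index a single segment after its content or keywords change.
    VectorService.update_segment_vector(['retrieval', 'vector'], segment_documents[0], dataset)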