Co-authored-by: jyong <718720800@qq.com>
| @@ -14,7 +14,7 @@ from flask import Flask, request, Response, session | |||
| import flask_login | |||
| from flask_cors import CORS | |||
| from extensions import ext_session, ext_celery, ext_sentry, ext_redis, ext_login, ext_vector_store, ext_migrate, \ | |||
| from extensions import ext_session, ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \ | |||
| ext_database, ext_storage | |||
| from extensions.ext_database import db | |||
| from extensions.ext_login import login_manager | |||
| @@ -79,7 +79,6 @@ def initialize_extensions(app): | |||
| ext_database.init_app(app) | |||
| ext_migrate.init(app, db) | |||
| ext_redis.init_app(app) | |||
| ext_vector_store.init_app(app) | |||
| ext_storage.init_app(app) | |||
| ext_celery.init_app(app) | |||
| ext_session.init_app(app) | |||
| @@ -1,15 +1,19 @@ | |||
| import datetime | |||
| import logging | |||
| import random | |||
| import string | |||
| import click | |||
| from flask import current_app | |||
| from werkzeug.exceptions import NotFound | |||
| from core.index.index import IndexBuilder | |||
| from libs.password import password_pattern, valid_password, hash_password | |||
| from libs.helper import email as email_validate | |||
| from extensions.ext_database import db | |||
| from libs.rsa import generate_key_pair | |||
| from models.account import InvitationCode, Tenant | |||
| from models.dataset import Dataset | |||
| from models.model import Account | |||
| import secrets | |||
| import base64 | |||
| @@ -159,8 +163,39 @@ def generate_upper_string(): | |||
| return result | |||
| @click.command('recreate-all-dataset-indexes', help='Recreate all dataset indexes.') | |||
| def recreate_all_dataset_indexes(): | |||
click.echo(click.style('Start recreating all dataset indexes.', fg='green'))
| recreate_count = 0 | |||
| page = 1 | |||
| while True: | |||
| try: | |||
| datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality')\ | |||
| .order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50) | |||
| except NotFound: | |||
| break | |||
| page += 1 | |||
| for dataset in datasets: | |||
| try: | |||
| click.echo('Recreating dataset index: {}'.format(dataset.id)) | |||
| index = IndexBuilder.get_index(dataset, 'high_quality') | |||
| if index and index._is_origin(): | |||
| index.recreate_dataset(dataset) | |||
| recreate_count += 1 | |||
| else: | |||
| click.echo('passed.') | |||
| except Exception as e: | |||
| click.echo(click.style('Recreate dataset index error: {} {}'.format(e.__class__.__name__, str(e)), fg='red')) | |||
| continue | |||
click.echo(click.style('Congratulations! Recreated {} dataset indexes.'.format(recreate_count), fg='green'))
| def register_commands(app): | |||
| app.cli.add_command(reset_password) | |||
| app.cli.add_command(reset_email) | |||
| app.cli.add_command(generate_invitation_codes) | |||
| app.cli.add_command(reset_encrypt_key_pair) | |||
| app.cli.add_command(recreate_all_dataset_indexes) | |||
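For reference, a minimal sketch of exercising the newly registered command without a shell, via Flask's CLI test runner (the `create_app` import is an assumption about the project's entry point, adjust to the real factory):

from app import create_app  # hypothetical entry point

app = create_app()
runner = app.test_cli_runner()
# invokes the click command registered above; note it will touch the database and vector store
result = runner.invoke(args=['recreate-all-dataset-indexes'])
print(result.output)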
| @@ -187,11 +187,13 @@ class Config: | |||
| # For temp use only | |||
| # set default LLM provider, default is 'openai', support `azure_openai` | |||
| self.DEFAULT_LLM_PROVIDER = get_env('DEFAULT_LLM_PROVIDER') | |||
| # notion import setting | |||
| self.NOTION_CLIENT_ID = get_env('NOTION_CLIENT_ID') | |||
| self.NOTION_CLIENT_SECRET = get_env('NOTION_CLIENT_SECRET') | |||
| self.NOTION_INTEGRATION_TYPE = get_env('NOTION_INTEGRATION_TYPE') | |||
| self.NOTION_INTERNAL_SECRET = get_env('NOTION_INTERNAL_SECRET') | |||
| self.NOTION_INTEGRATION_TOKEN = get_env('NOTION_INTEGRATION_TOKEN') | |||
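A hedged sketch of supplying these settings at runtime (the variable names come from the diff; the values and the internal/public split are illustrative assumptions):

import os

# internal integration: a single integration secret/token is enough
os.environ['NOTION_INTEGRATION_TYPE'] = 'internal'
os.environ['NOTION_INTERNAL_SECRET'] = 'secret-placeholder'
# a public OAuth integration would instead set NOTION_CLIENT_ID / NOTION_CLIENT_SECRET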
| class CloudEditionConfig(Config): | |||
| @@ -10,11 +10,10 @@ from werkzeug.exceptions import NotFound | |||
| from controllers.console import api | |||
| from controllers.console.setup import setup_required | |||
| from controllers.console.wraps import account_initialization_required | |||
| from core.data_source.notion import NotionPageReader | |||
| from core.data_loader.loader.notion import NotionLoader | |||
| from core.indexing_runner import IndexingRunner | |||
| from extensions.ext_database import db | |||
| from libs.helper import TimestampField | |||
| from libs.oauth_data_source import NotionOAuth | |||
| from models.dataset import Document | |||
| from models.source import DataSourceBinding | |||
| from services.dataset_service import DatasetService, DocumentService | |||
| @@ -232,15 +231,17 @@ class DataSourceNotionApi(Resource): | |||
| ).first() | |||
| if not data_source_binding: | |||
| raise NotFound('Data source binding not found.') | |||
| reader = NotionPageReader(integration_token=data_source_binding.access_token) | |||
| if page_type == 'page': | |||
| page_content = reader.read_page(page_id) | |||
| elif page_type == 'database': | |||
| page_content = reader.query_database_data(page_id) | |||
| else: | |||
| page_content = "" | |||
| loader = NotionLoader( | |||
| notion_access_token=data_source_binding.access_token, | |||
| notion_workspace_id=workspace_id, | |||
| notion_obj_id=page_id, | |||
| notion_page_type=page_type | |||
| ) | |||
| text_docs = loader.load() | |||
| return { | |||
| 'content': page_content | |||
| 'content': "\n".join([doc.page_content for doc in text_docs]) | |||
| }, 200 | |||
| @setup_required | |||
| @@ -17,9 +17,7 @@ from controllers.console.datasets.error import NoFileUploadedError, TooManyFiles | |||
| UnsupportedFileTypeError | |||
| from controllers.console.setup import setup_required | |||
| from controllers.console.wraps import account_initialization_required | |||
| from core.index.readers.html_parser import HTMLParser | |||
| from core.index.readers.pdf_parser import PDFParser | |||
| from core.index.readers.xlsx_parser import XLSXParser | |||
| from core.data_loader.file_extractor import FileExtractor | |||
| from extensions.ext_storage import storage | |||
| from libs.helper import TimestampField | |||
| from extensions.ext_database import db | |||
| @@ -123,31 +121,7 @@ class FilePreviewApi(Resource): | |||
| if extension not in ALLOWED_EXTENSIONS: | |||
| raise UnsupportedFileTypeError() | |||
| with tempfile.TemporaryDirectory() as temp_dir: | |||
| suffix = Path(upload_file.key).suffix | |||
| filepath = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" | |||
| storage.download(upload_file.key, filepath) | |||
| if extension == 'pdf': | |||
| parser = PDFParser({'upload_file': upload_file}) | |||
| text = parser.parse_file(Path(filepath)) | |||
| elif extension in ['html', 'htm']: | |||
| # Use BeautifulSoup to extract text | |||
| parser = HTMLParser() | |||
| text = parser.parse_file(Path(filepath)) | |||
| elif extension == 'xlsx': | |||
| parser = XLSXParser() | |||
| text = parser.parse_file(filepath) | |||
| else: | |||
| # ['txt', 'markdown', 'md'] | |||
| with open(filepath, "rb") as fp: | |||
| data = fp.read() | |||
| encoding = chardet.detect(data)['encoding'] | |||
| if encoding: | |||
| text = data.decode(encoding=encoding).strip() if data else '' | |||
| else: | |||
| text = data.decode(encoding='utf-8').strip() if data else '' | |||
| text = FileExtractor.load(upload_file, return_text=True) | |||
| text = text[0:PREVIEW_WORDS_LIMIT] if text else '' | |||
| return {'content': text} | |||
| @@ -32,8 +32,13 @@ class VersionApi(Resource): | |||
| 'current_version': args.get('current_version') | |||
| }) | |||
| except Exception as error: | |||
| logging.exception("Check update error.") | |||
| raise InternalServerError() | |||
| logging.warning("Check update version error: {}.".format(str(error))) | |||
| return { | |||
| 'version': args.get('current_version'), | |||
| 'release_date': '', | |||
| 'release_notes': '', | |||
| 'can_auto_update': False | |||
| } | |||
| content = json.loads(response.content) | |||
| return { | |||
| @@ -3,19 +3,11 @@ from typing import Optional | |||
| import langchain | |||
| from flask import Flask | |||
| from jieba.analyse import default_tfidf | |||
| from langchain import set_handler | |||
| from langchain.prompts.base import DEFAULT_FORMATTER_MAPPING | |||
| from llama_index import IndexStructType, QueryMode | |||
| from llama_index.indices.registry import INDEX_STRUT_TYPE_TO_QUERY_MAP | |||
| from pydantic import BaseModel | |||
| from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler | |||
| from core.index.keyword_table.jieba_keyword_table import GPTJIEBAKeywordTableIndex | |||
| from core.index.keyword_table.stopwords import STOPWORDS | |||
| from core.prompt.prompt_template import OneLineFormatter | |||
| from core.vector_store.vector_store import VectorStore | |||
| from core.vector_store.vector_store_index_query import EnhanceGPTVectorStoreIndexQuery | |||
| class HostedOpenAICredential(BaseModel): | |||
| @@ -32,21 +24,9 @@ hosted_llm_credentials = HostedLLMCredentials() | |||
| def init_app(app: Flask): | |||
| formatter = OneLineFormatter() | |||
| DEFAULT_FORMATTER_MAPPING['f-string'] = formatter.format | |||
| INDEX_STRUT_TYPE_TO_QUERY_MAP[IndexStructType.KEYWORD_TABLE] = GPTJIEBAKeywordTableIndex.get_query_map() | |||
| INDEX_STRUT_TYPE_TO_QUERY_MAP[IndexStructType.WEAVIATE] = { | |||
| QueryMode.DEFAULT: EnhanceGPTVectorStoreIndexQuery, | |||
| QueryMode.EMBEDDING: EnhanceGPTVectorStoreIndexQuery, | |||
| } | |||
| INDEX_STRUT_TYPE_TO_QUERY_MAP[IndexStructType.QDRANT] = { | |||
| QueryMode.DEFAULT: EnhanceGPTVectorStoreIndexQuery, | |||
| QueryMode.EMBEDDING: EnhanceGPTVectorStoreIndexQuery, | |||
| } | |||
| default_tfidf.stop_words = STOPWORDS | |||
| if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true': | |||
| langchain.verbose = True | |||
| set_handler(DifyStdOutCallbackHandler()) | |||
| if app.config.get("OPENAI_API_KEY"): | |||
| hosted_llm_credentials.openai = HostedOpenAICredential(api_key=app.config.get("OPENAI_API_KEY")) | |||
| @@ -2,7 +2,7 @@ from typing import Optional | |||
| from langchain import LLMChain | |||
| from langchain.agents import ZeroShotAgent, AgentExecutor, ConversationalAgent | |||
| from langchain.callbacks import CallbackManager | |||
| from langchain.callbacks.manager import CallbackManager | |||
| from langchain.memory.chat_memory import BaseChatMemory | |||
| from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler | |||
| @@ -16,23 +16,20 @@ class AgentBuilder: | |||
| def to_agent_chain(cls, tenant_id: str, tools, memory: Optional[BaseChatMemory], | |||
| dataset_tool_callback_handler: DatasetToolCallbackHandler, | |||
| agent_loop_gather_callback_handler: AgentLoopGatherCallbackHandler): | |||
| llm_callback_manager = CallbackManager([agent_loop_gather_callback_handler, DifyStdOutCallbackHandler()]) | |||
| llm = LLMBuilder.to_llm( | |||
| tenant_id=tenant_id, | |||
| model_name=agent_loop_gather_callback_handler.model_name, | |||
| temperature=0, | |||
| max_tokens=1024, | |||
| callback_manager=llm_callback_manager | |||
| callbacks=[agent_loop_gather_callback_handler, DifyStdOutCallbackHandler()] | |||
| ) | |||
| tool_callback_manager = CallbackManager([ | |||
| agent_loop_gather_callback_handler, | |||
| dataset_tool_callback_handler, | |||
| DifyStdOutCallbackHandler() | |||
| ]) | |||
| for tool in tools: | |||
| tool.callback_manager = tool_callback_manager | |||
| tool.callbacks = [ | |||
| agent_loop_gather_callback_handler, | |||
| dataset_tool_callback_handler, | |||
| DifyStdOutCallbackHandler() | |||
| ] | |||
| prompt = cls.build_agent_prompt_template( | |||
| tools=tools, | |||
| @@ -54,7 +51,7 @@ class AgentBuilder: | |||
| tools=tools, | |||
| agent=agent, | |||
| memory=memory, | |||
| callback_manager=agent_callback_manager, | |||
| callbacks=agent_callback_manager, | |||
| max_iterations=6, | |||
| early_stopping_method="generate", | |||
| # `generate` will continue to complete the last inference after reaching the iteration limit or request time limit | |||
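These hunks replace the deprecated CallbackManager wiring with plain handler lists. A minimal, self-contained sketch of that pattern (assuming a langchain version whose components accept a `callbacks` list, as the new imports imply):

from typing import Any, Dict, List
from langchain.callbacks.base import BaseCallbackHandler

class PrintingHandler(BaseCallbackHandler):
    def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
        print('llm starting with {} prompt(s)'.format(len(prompts)))

# handlers are now attached directly where components are built or invoked, e.g.
# llm = LLMBuilder.to_llm(..., callbacks=[PrintingHandler()])
# tool.callbacks = [PrintingHandler()]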
| @@ -12,6 +12,7 @@ from core.conversation_message_task import ConversationMessageTask | |||
| class AgentLoopGatherCallbackHandler(BaseCallbackHandler): | |||
| """Callback Handler that prints to std out.""" | |||
| raise_error: bool = True | |||
| def __init__(self, model_name, conversation_message_task: ConversationMessageTask) -> None: | |||
| """Initialize callback handler.""" | |||
| @@ -64,10 +65,6 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler): | |||
| self._current_loop.completion = response.generations[0][0].text | |||
| self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens'] | |||
| def on_llm_new_token(self, token: str, **kwargs: Any) -> None: | |||
| """Do nothing.""" | |||
| pass | |||
| def on_llm_error( | |||
| self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any | |||
| ) -> None: | |||
| @@ -75,21 +72,6 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler): | |||
| self._agent_loops = [] | |||
| self._current_loop = None | |||
| def on_chain_start( | |||
| self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any | |||
| ) -> None: | |||
| """Print out that we are entering a chain.""" | |||
| pass | |||
| def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: | |||
| """Print out that we finished a chain.""" | |||
| pass | |||
| def on_chain_error( | |||
| self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any | |||
| ) -> None: | |||
| logging.error(error) | |||
| def on_tool_start( | |||
| self, | |||
| serialized: Dict[str, Any], | |||
| @@ -151,16 +133,6 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler): | |||
| self._agent_loops = [] | |||
| self._current_loop = None | |||
| def on_text( | |||
| self, | |||
| text: str, | |||
| color: Optional[str] = None, | |||
| end: str = "", | |||
| **kwargs: Optional[str], | |||
| ) -> None: | |||
| """Run on additional input from chains and agents.""" | |||
| pass | |||
| def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any: | |||
| """Run on agent end.""" | |||
| # Final Answer | |||
| @@ -3,7 +3,6 @@ import logging | |||
| from typing import Any, Dict, List, Union, Optional | |||
| from langchain.callbacks.base import BaseCallbackHandler | |||
| from langchain.schema import AgentAction, AgentFinish, LLMResult | |||
| from core.callback_handler.entity.dataset_query import DatasetQueryObj | |||
| from core.conversation_message_task import ConversationMessageTask | |||
| @@ -11,6 +10,7 @@ from core.conversation_message_task import ConversationMessageTask | |||
| class DatasetToolCallbackHandler(BaseCallbackHandler): | |||
| """Callback Handler that prints to std out.""" | |||
| raise_error: bool = True | |||
| def __init__(self, conversation_message_task: ConversationMessageTask) -> None: | |||
| """Initialize callback handler.""" | |||
| @@ -66,52 +66,3 @@ class DatasetToolCallbackHandler(BaseCallbackHandler): | |||
| ) -> None: | |||
| """Do nothing.""" | |||
| logging.error(error) | |||
| def on_chain_start( | |||
| self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any | |||
| ) -> None: | |||
| pass | |||
| def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: | |||
| pass | |||
| def on_chain_error( | |||
| self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any | |||
| ) -> None: | |||
| pass | |||
| def on_llm_start( | |||
| self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any | |||
| ) -> None: | |||
| pass | |||
| def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: | |||
| pass | |||
| def on_llm_new_token(self, token: str, **kwargs: Any) -> None: | |||
| """Do nothing.""" | |||
| pass | |||
| def on_llm_error( | |||
| self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any | |||
| ) -> None: | |||
| logging.error(error) | |||
| def on_agent_action( | |||
| self, action: AgentAction, color: Optional[str] = None, **kwargs: Any | |||
| ) -> Any: | |||
| pass | |||
| def on_text( | |||
| self, | |||
| text: str, | |||
| color: Optional[str] = None, | |||
| end: str = "", | |||
| **kwargs: Optional[str], | |||
| ) -> None: | |||
| """Run on additional input from chains and agents.""" | |||
| pass | |||
| def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any: | |||
| """Run on agent end.""" | |||
| pass | |||
| @@ -1,39 +1,26 @@ | |||
| from llama_index import Response | |||
| from typing import List | |||
| from langchain.schema import Document | |||
| from extensions.ext_database import db | |||
| from models.dataset import DocumentSegment | |||
| class IndexToolCallbackHandler: | |||
| def __init__(self) -> None: | |||
| self._response = None | |||
| @property | |||
| def response(self) -> Response: | |||
| return self._response | |||
| def on_tool_end(self, response: Response) -> None: | |||
| """Handle tool end.""" | |||
| self._response = response | |||
| class DatasetIndexToolCallbackHandler(IndexToolCallbackHandler): | |||
| class DatasetIndexToolCallbackHandler: | |||
| """Callback handler for dataset tool.""" | |||
| def __init__(self, dataset_id: str) -> None: | |||
| super().__init__() | |||
| self.dataset_id = dataset_id | |||
| def on_tool_end(self, response: Response) -> None: | |||
| def on_tool_end(self, documents: List[Document]) -> None: | |||
| """Handle tool end.""" | |||
| for node in response.source_nodes: | |||
| index_node_id = node.node.doc_id | |||
| for document in documents: | |||
| doc_id = document.metadata['doc_id'] | |||
| # add hit count to document segment | |||
| db.session.query(DocumentSegment).filter( | |||
| DocumentSegment.dataset_id == self.dataset_id, | |||
| DocumentSegment.index_node_id == index_node_id | |||
| DocumentSegment.index_node_id == doc_id | |||
| ).update( | |||
| {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, | |||
| synchronize_session=False | |||
| @@ -3,7 +3,7 @@ import time | |||
| from typing import Any, Dict, List, Union, Optional | |||
| from langchain.callbacks.base import BaseCallbackHandler | |||
| from langchain.schema import AgentAction, AgentFinish, LLMResult, HumanMessage, AIMessage, SystemMessage | |||
| from langchain.schema import AgentAction, AgentFinish, LLMResult, HumanMessage, AIMessage, SystemMessage, BaseMessage | |||
| from core.callback_handler.entity.llm_message import LLMMessage | |||
| from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException | |||
| @@ -12,6 +12,7 @@ from core.llm.streamable_open_ai import StreamableOpenAI | |||
| class LLMCallbackHandler(BaseCallbackHandler): | |||
| raise_error: bool = True | |||
| def __init__(self, llm: Union[StreamableOpenAI, StreamableChatOpenAI], | |||
| conversation_message_task: ConversationMessageTask): | |||
| @@ -25,41 +26,41 @@ class LLMCallbackHandler(BaseCallbackHandler): | |||
| """Whether to call verbose callbacks even if verbose is False.""" | |||
| return True | |||
| def on_chat_model_start( | |||
| self, | |||
| serialized: Dict[str, Any], | |||
| messages: List[List[BaseMessage]], | |||
| **kwargs: Any | |||
| ) -> Any: | |||
| self.start_at = time.perf_counter() | |||
| real_prompts = [] | |||
| for message in messages[0]: | |||
| if message.type == 'human': | |||
| role = 'user' | |||
| elif message.type == 'ai': | |||
| role = 'assistant' | |||
| else: | |||
| role = 'system' | |||
| real_prompts.append({ | |||
| "role": role, | |||
| "text": message.content | |||
| }) | |||
| self.llm_message.prompt = real_prompts | |||
| self.llm_message.prompt_tokens = self.llm.get_messages_tokens(messages[0]) | |||
| def on_llm_start( | |||
| self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any | |||
| ) -> None: | |||
| self.start_at = time.perf_counter() | |||
| if 'Chat' in serialized['name']: | |||
| real_prompts = [] | |||
| messages = [] | |||
| for prompt in prompts: | |||
| role, content = prompt.split(': ', maxsplit=1) | |||
| if role == 'human': | |||
| role = 'user' | |||
| message = HumanMessage(content=content) | |||
| elif role == 'ai': | |||
| role = 'assistant' | |||
| message = AIMessage(content=content) | |||
| else: | |||
| message = SystemMessage(content=content) | |||
| real_prompt = { | |||
| "role": role, | |||
| "text": content | |||
| } | |||
| real_prompts.append(real_prompt) | |||
| messages.append(message) | |||
| self.llm_message.prompt = real_prompts | |||
| self.llm_message.prompt_tokens = self.llm.get_messages_tokens(messages) | |||
| else: | |||
| self.llm_message.prompt = [{ | |||
| "role": 'user', | |||
| "text": prompts[0] | |||
| }] | |||
| self.llm_message.prompt = [{ | |||
| "role": 'user', | |||
| "text": prompts[0] | |||
| }] | |||
| self.llm_message.prompt_tokens = self.llm.get_num_tokens(prompts[0]) | |||
| self.llm_message.prompt_tokens = self.llm.get_num_tokens(prompts[0]) | |||
| def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: | |||
| end_at = time.perf_counter() | |||
| @@ -95,58 +96,3 @@ class LLMCallbackHandler(BaseCallbackHandler): | |||
| self.conversation_message_task.save_message(llm_message=self.llm_message, by_stopped=True) | |||
| else: | |||
| logging.error(error) | |||
| def on_chain_start( | |||
| self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any | |||
| ) -> None: | |||
| pass | |||
| def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: | |||
| pass | |||
| def on_chain_error( | |||
| self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any | |||
| ) -> None: | |||
| pass | |||
| def on_tool_start( | |||
| self, | |||
| serialized: Dict[str, Any], | |||
| input_str: str, | |||
| **kwargs: Any, | |||
| ) -> None: | |||
| pass | |||
| def on_agent_action( | |||
| self, action: AgentAction, color: Optional[str] = None, **kwargs: Any | |||
| ) -> Any: | |||
| pass | |||
| def on_tool_end( | |||
| self, | |||
| output: str, | |||
| color: Optional[str] = None, | |||
| observation_prefix: Optional[str] = None, | |||
| llm_prefix: Optional[str] = None, | |||
| **kwargs: Any, | |||
| ) -> None: | |||
| pass | |||
| def on_tool_error( | |||
| self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any | |||
| ) -> None: | |||
| pass | |||
| def on_text( | |||
| self, | |||
| text: str, | |||
| color: Optional[str] = None, | |||
| end: str = "", | |||
| **kwargs: Optional[str], | |||
| ) -> None: | |||
| pass | |||
| def on_agent_finish( | |||
| self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any | |||
| ) -> None: | |||
| pass | |||
| @@ -1,10 +1,9 @@ | |||
| import logging | |||
| import time | |||
| from typing import Any, Dict, List, Union, Optional | |||
| from typing import Any, Dict, Union | |||
| from langchain.callbacks.base import BaseCallbackHandler | |||
| from langchain.schema import AgentAction, AgentFinish, LLMResult | |||
| from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler | |||
| from core.callback_handler.entity.chain_result import ChainResult | |||
| @@ -14,6 +13,7 @@ from core.conversation_message_task import ConversationMessageTask | |||
| class MainChainGatherCallbackHandler(BaseCallbackHandler): | |||
| """Callback Handler that prints to std out.""" | |||
| raise_error: bool = True | |||
| def __init__(self, conversation_message_task: ConversationMessageTask) -> None: | |||
| """Initialize callback handler.""" | |||
| @@ -50,13 +50,15 @@ class MainChainGatherCallbackHandler(BaseCallbackHandler): | |||
| ) -> None: | |||
| """Print out that we are entering a chain.""" | |||
| if not self._current_chain_result: | |||
| self._current_chain_result = ChainResult( | |||
| type=serialized['name'], | |||
| prompt=inputs, | |||
| started_at=time.perf_counter() | |||
| ) | |||
| self._current_chain_message = self.conversation_message_task.init_chain(self._current_chain_result) | |||
| self.agent_loop_gather_callback_handler.current_chain = self._current_chain_message | |||
| chain_type = serialized['id'][-1] | |||
| if chain_type: | |||
| self._current_chain_result = ChainResult( | |||
| type=chain_type, | |||
| prompt=inputs, | |||
| started_at=time.perf_counter() | |||
| ) | |||
| self._current_chain_message = self.conversation_message_task.init_chain(self._current_chain_result) | |||
| self.agent_loop_gather_callback_handler.current_chain = self._current_chain_message | |||
| def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: | |||
| """Print out that we finished a chain.""" | |||
| @@ -74,64 +76,4 @@ class MainChainGatherCallbackHandler(BaseCallbackHandler): | |||
| self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any | |||
| ) -> None: | |||
| logging.error(error) | |||
| self.clear_chain_results() | |||
| def on_llm_start( | |||
| self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any | |||
| ) -> None: | |||
| pass | |||
| def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: | |||
| pass | |||
| def on_llm_new_token(self, token: str, **kwargs: Any) -> None: | |||
| """Do nothing.""" | |||
| pass | |||
| def on_llm_error( | |||
| self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any | |||
| ) -> None: | |||
| logging.error(error) | |||
| def on_tool_start( | |||
| self, | |||
| serialized: Dict[str, Any], | |||
| input_str: str, | |||
| **kwargs: Any, | |||
| ) -> None: | |||
| pass | |||
| def on_agent_action( | |||
| self, action: AgentAction, color: Optional[str] = None, **kwargs: Any | |||
| ) -> Any: | |||
| pass | |||
| def on_tool_end( | |||
| self, | |||
| output: str, | |||
| color: Optional[str] = None, | |||
| observation_prefix: Optional[str] = None, | |||
| llm_prefix: Optional[str] = None, | |||
| **kwargs: Any, | |||
| ) -> None: | |||
| pass | |||
| def on_tool_error( | |||
| self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any | |||
| ) -> None: | |||
| """Do nothing.""" | |||
| logging.error(error) | |||
| def on_text( | |||
| self, | |||
| text: str, | |||
| color: Optional[str] = None, | |||
| end: str = "", | |||
| **kwargs: Optional[str], | |||
| ) -> None: | |||
| """Run on additional input from chains and agents.""" | |||
| pass | |||
| def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any: | |||
| """Run on agent end.""" | |||
| pass | |||
| self.clear_chain_results() | |||
| @@ -1,9 +1,10 @@ | |||
| import os | |||
| import sys | |||
| from typing import Any, Dict, List, Optional, Union | |||
| from langchain.callbacks.base import BaseCallbackHandler | |||
| from langchain.input import print_text | |||
| from langchain.schema import AgentAction, AgentFinish, LLMResult | |||
| from langchain.schema import AgentAction, AgentFinish, LLMResult, BaseMessage | |||
| class DifyStdOutCallbackHandler(BaseCallbackHandler): | |||
| @@ -13,17 +14,23 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler): | |||
| """Initialize callback handler.""" | |||
| self.color = color | |||
| def on_chat_model_start( | |||
| self, | |||
| serialized: Dict[str, Any], | |||
| messages: List[List[BaseMessage]], | |||
| **kwargs: Any | |||
| ) -> Any: | |||
| print_text("\n[on_chat_model_start]\n", color='blue') | |||
| for sub_messages in messages: | |||
| for sub_message in sub_messages: | |||
| print_text(str(sub_message) + "\n", color='blue') | |||
| def on_llm_start( | |||
| self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any | |||
| ) -> None: | |||
| """Print out the prompts.""" | |||
| print_text("\n[on_llm_start]\n", color='blue') | |||
| if 'Chat' in serialized['name']: | |||
| for prompt in prompts: | |||
| print_text(prompt + "\n", color='blue') | |||
| else: | |||
| print_text(prompts[0] + "\n", color='blue') | |||
| print_text(prompts[0] + "\n", color='blue') | |||
| def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: | |||
| """Do nothing.""" | |||
| @@ -44,8 +51,8 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler): | |||
| self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any | |||
| ) -> None: | |||
| """Print out that we are entering a chain.""" | |||
| class_name = serialized["name"] | |||
| print_text("\n[on_chain_start]\nChain: " + class_name + "\nInputs: " + str(inputs) + "\n", color='pink') | |||
| chain_type = serialized['id'][-1] | |||
| print_text("\n[on_chain_start]\nChain: " + chain_type + "\nInputs: " + str(inputs) + "\n", color='pink') | |||
| def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: | |||
| """Print out that we finished a chain.""" | |||
| @@ -117,6 +124,26 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler): | |||
| """Run on agent end.""" | |||
| print_text("[on_agent_finish] " + finish.return_values['output'] + "\n", color='green', end="\n") | |||
| @property | |||
| def ignore_llm(self) -> bool: | |||
| """Whether to ignore LLM callbacks.""" | |||
| return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true' | |||
| @property | |||
| def ignore_chain(self) -> bool: | |||
| """Whether to ignore chain callbacks.""" | |||
| return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true' | |||
| @property | |||
| def ignore_agent(self) -> bool: | |||
| """Whether to ignore agent callbacks.""" | |||
| return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true' | |||
| @property | |||
| def ignore_chat_model(self) -> bool: | |||
| """Whether to ignore chat model callbacks.""" | |||
| return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true' | |||
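The four ignore_* properties above repeat one predicate; an equivalent helper, shown only as a sketch and not part of the diff:

import os

def _not_debug() -> bool:
    # True unless DEBUG is explicitly set to "true" (case-insensitive)
    return (os.environ.get('DEBUG') or '').lower() != 'true'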
| class DifyStreamingStdOutCallbackHandler(DifyStdOutCallbackHandler): | |||
| """Callback handler for streaming. Only works with LLMs that support streaming.""" | |||
| @@ -1,7 +1,5 @@ | |||
| from typing import Optional | |||
| from langchain.callbacks import CallbackManager | |||
| from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler | |||
| from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain | |||
| from core.chain.tool_chain import ToolChain | |||
| @@ -14,7 +12,7 @@ class ChainBuilder: | |||
| tool=tool, | |||
| input_key=kwargs.get('input_key', 'input'), | |||
| output_key=kwargs.get('output_key', 'tool_output'), | |||
| callback_manager=CallbackManager([DifyStdOutCallbackHandler()]) | |||
| callbacks=[DifyStdOutCallbackHandler()] | |||
| ) | |||
| @classmethod | |||
| @@ -27,7 +25,7 @@ class ChainBuilder: | |||
| sensitive_words=sensitive_words.split(","), | |||
| canned_response=tool_config.get("canned_response", ''), | |||
| output_key="sensitive_word_avoidance_output", | |||
| callback_manager=CallbackManager([DifyStdOutCallbackHandler()]), | |||
| callbacks=[DifyStdOutCallbackHandler()], | |||
| **kwargs | |||
| ) | |||
| @@ -1,15 +1,16 @@ | |||
| """Base classes for LLM-powered router chains.""" | |||
| from __future__ import annotations | |||
| import json | |||
| from typing import Any, Dict, List, Optional, Type, cast, NamedTuple | |||
| from langchain.base_language import BaseLanguageModel | |||
| from langchain.callbacks.manager import CallbackManagerForChainRun | |||
| from langchain.chains.base import Chain | |||
| from pydantic import root_validator | |||
| from langchain.chains import LLMChain | |||
| from langchain.prompts import BasePromptTemplate | |||
| from langchain.schema import BaseOutputParser, OutputParserException, BaseLanguageModel | |||
| from langchain.schema import BaseOutputParser, OutputParserException | |||
| from libs.json_in_md_parser import parse_and_check_json_markdown | |||
| @@ -51,8 +52,9 @@ class LLMRouterChain(Chain): | |||
| raise ValueError | |||
| def _call( | |||
| self, | |||
| inputs: Dict[str, Any] | |||
| self, | |||
| inputs: Dict[str, Any], | |||
| run_manager: Optional[CallbackManagerForChainRun] = None, | |||
| ) -> Dict[str, Any]: | |||
| output = cast( | |||
| Dict[str, Any], | |||
| @@ -1,11 +1,9 @@ | |||
| from typing import Optional, List | |||
| from typing import Optional, List, cast | |||
| from langchain.callbacks import SharedCallbackManager, CallbackManager | |||
| from langchain.chains import SequentialChain | |||
| from langchain.chains.base import Chain | |||
| from langchain.memory.chat_memory import BaseChatMemory | |||
| from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler | |||
| from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler | |||
| from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler | |||
| from core.chain.chain_builder import ChainBuilder | |||
| @@ -18,6 +16,7 @@ from models.dataset import Dataset | |||
| class MainChainBuilder: | |||
| @classmethod | |||
| def to_langchain_components(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory], | |||
| rest_tokens: int, | |||
| conversation_message_task: ConversationMessageTask): | |||
| first_input_key = "input" | |||
| final_output_key = "output" | |||
| @@ -30,6 +29,7 @@ class MainChainBuilder: | |||
| tool_chains, chains_output_key = cls.get_agent_chains( | |||
| tenant_id=tenant_id, | |||
| agent_mode=agent_mode, | |||
| rest_tokens=rest_tokens, | |||
| memory=memory, | |||
| conversation_message_task=conversation_message_task | |||
| ) | |||
| @@ -42,9 +42,8 @@ class MainChainBuilder: | |||
| return None | |||
| for chain in chains: | |||
| # do not add handler into singleton callback manager | |||
| if not isinstance(chain.callback_manager, SharedCallbackManager): | |||
| chain.callback_manager.add_handler(chain_callback_handler) | |||
| chain = cast(Chain, chain) | |||
| chain.callbacks.append(chain_callback_handler) | |||
| # build main chain | |||
| overall_chain = SequentialChain( | |||
| @@ -57,7 +56,9 @@ class MainChainBuilder: | |||
| return overall_chain | |||
| @classmethod | |||
| def get_agent_chains(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory], | |||
| def get_agent_chains(cls, tenant_id: str, agent_mode: dict, | |||
| rest_tokens: int, | |||
| memory: Optional[BaseChatMemory], | |||
| conversation_message_task: ConversationMessageTask): | |||
| # agent mode | |||
| chains = [] | |||
| @@ -93,7 +94,8 @@ class MainChainBuilder: | |||
| tenant_id=tenant_id, | |||
| datasets=datasets, | |||
| conversation_message_task=conversation_message_task, | |||
| callback_manager=CallbackManager([DifyStdOutCallbackHandler()]) | |||
| rest_tokens=rest_tokens, | |||
| callbacks=[DifyStdOutCallbackHandler()] | |||
| ) | |||
| chains.append(multi_dataset_router_chain) | |||
| @@ -1,9 +1,9 @@ | |||
| import math | |||
| from typing import Mapping, List, Dict, Any, Optional | |||
| from langchain import LLMChain, PromptTemplate, ConversationChain | |||
| from langchain.callbacks import CallbackManager | |||
| from langchain import PromptTemplate | |||
| from langchain.callbacks.manager import CallbackManagerForChainRun | |||
| from langchain.chains.base import Chain | |||
| from langchain.schema import BaseLanguageModel | |||
| from pydantic import Extra | |||
| from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler | |||
| @@ -11,10 +11,11 @@ from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHan | |||
| from core.chain.llm_router_chain import LLMRouterChain, RouterOutputParser | |||
| from core.conversation_message_task import ConversationMessageTask | |||
| from core.llm.llm_builder import LLMBuilder | |||
| from core.tool.dataset_tool_builder import DatasetToolBuilder | |||
| from core.tool.llama_index_tool import EnhanceLlamaIndexTool | |||
| from models.dataset import Dataset | |||
| from core.tool.dataset_index_tool import DatasetTool | |||
| from models.dataset import Dataset, DatasetProcessRule | |||
| DEFAULT_K = 2 | |||
| CONTEXT_TOKENS_PERCENT = 0.3 | |||
| MULTI_PROMPT_ROUTER_TEMPLATE = """ | |||
| Given a raw text input to a language model select the model prompt best suited for \ | |||
| the input. You will be given the names of the available prompts and a description of \ | |||
| @@ -52,7 +53,7 @@ class MultiDatasetRouterChain(Chain): | |||
| router_chain: LLMRouterChain | |||
| """Chain for deciding a destination chain and the input to it.""" | |||
| dataset_tools: Mapping[str, EnhanceLlamaIndexTool] | |||
| dataset_tools: Mapping[str, DatasetTool] | |||
| """Map of name to candidate chains that inputs can be routed to.""" | |||
| class Config: | |||
| @@ -79,41 +80,56 @@ class MultiDatasetRouterChain(Chain): | |||
| tenant_id: str, | |||
| datasets: List[Dataset], | |||
| conversation_message_task: ConversationMessageTask, | |||
| rest_tokens: int, | |||
| **kwargs: Any, | |||
| ): | |||
| """Convenience constructor for instantiating from destination prompts.""" | |||
| llm_callback_manager = CallbackManager([DifyStdOutCallbackHandler()]) | |||
| llm = LLMBuilder.to_llm( | |||
| tenant_id=tenant_id, | |||
| model_name='gpt-3.5-turbo', | |||
| temperature=0, | |||
| max_tokens=1024, | |||
| callback_manager=llm_callback_manager | |||
| callbacks=[DifyStdOutCallbackHandler()] | |||
| ) | |||
| destinations = ["{}: {}".format(d.id, d.description.replace('\n', ' ') if d.description | |||
| destinations = ["[[{}]]: {}".format(d.id, d.description.replace('\n', ' ') if d.description | |||
| else ('useful for when you want to answer queries about the ' + d.name)) | |||
| for d in datasets] | |||
| destinations_str = "\n".join(destinations) | |||
| router_template = MULTI_PROMPT_ROUTER_TEMPLATE.format( | |||
| destinations=destinations_str | |||
| ) | |||
| router_prompt = PromptTemplate( | |||
| template=router_template, | |||
| input_variables=["input"], | |||
| output_parser=RouterOutputParser(), | |||
| ) | |||
| router_chain = LLMRouterChain.from_llm(llm, router_prompt) | |||
| dataset_tools = {} | |||
| for dataset in datasets: | |||
| dataset_tool = DatasetToolBuilder.build_dataset_tool( | |||
# skip datasets that have no available documents
if dataset.available_document_count == 0:
continue
# fulfill description when it is empty
| description = dataset.description | |||
| if not description: | |||
| description = 'useful for when you want to answer queries about the ' + dataset.name | |||
| k = cls._dynamic_calc_retrieve_k(dataset, rest_tokens) | |||
| if k == 0: | |||
| continue | |||
| dataset_tool = DatasetTool( | |||
| name=f"dataset-{dataset.id}", | |||
| description=description, | |||
| k=k, | |||
| dataset=dataset, | |||
| response_mode='no_synthesizer', # "compact" | |||
| callback_handler=DatasetToolCallbackHandler(conversation_message_task) | |||
| callbacks=[DatasetToolCallbackHandler(conversation_message_task), DifyStdOutCallbackHandler()] | |||
| ) | |||
| if dataset_tool: | |||
| dataset_tools[dataset.id] = dataset_tool | |||
| dataset_tools[str(dataset.id)] = dataset_tool | |||
| return cls( | |||
| router_chain=router_chain, | |||
| @@ -121,9 +137,39 @@ class MultiDatasetRouterChain(Chain): | |||
| **kwargs, | |||
| ) | |||
| @classmethod | |||
| def _dynamic_calc_retrieve_k(cls, dataset: Dataset, rest_tokens: int) -> int: | |||
| processing_rule = dataset.latest_process_rule | |||
| if not processing_rule: | |||
| return DEFAULT_K | |||
| if processing_rule.mode == "custom": | |||
| rules = processing_rule.rules_dict | |||
| if not rules: | |||
| return DEFAULT_K | |||
| segmentation = rules["segmentation"] | |||
| segment_max_tokens = segmentation["max_tokens"] | |||
| else: | |||
| segment_max_tokens = DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'] | |||
| # when rest_tokens is less than default context tokens | |||
| if rest_tokens < segment_max_tokens * DEFAULT_K: | |||
| return rest_tokens // segment_max_tokens | |||
| context_limit_tokens = math.floor(rest_tokens * CONTEXT_TOKENS_PERCENT) | |||
| # when context_limit_tokens is less than default context tokens, use default_k | |||
| if context_limit_tokens <= segment_max_tokens * DEFAULT_K: | |||
| return DEFAULT_K | |||
| # Expand the k value when there's still some room left in the 30% rest tokens space | |||
| return context_limit_tokens // segment_max_tokens | |||
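A worked example of the heuristic above; the constants match the code, the token counts are illustrative assumptions:

DEFAULT_K = 2
CONTEXT_TOKENS_PERCENT = 0.3
segment_max_tokens = 500   # from the dataset's processing rule
rest_tokens = 10000        # tokens left after prompt and completion budget

if rest_tokens < segment_max_tokens * DEFAULT_K:                        # 10000 < 1000 -> False
    k = rest_tokens // segment_max_tokens
else:
    context_limit_tokens = int(rest_tokens * CONTEXT_TOKENS_PERCENT)    # 3000
    if context_limit_tokens <= segment_max_tokens * DEFAULT_K:          # 3000 <= 1000 -> False
        k = DEFAULT_K
    else:
        k = context_limit_tokens // segment_max_tokens                  # k == 6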
| def _call( | |||
| self, | |||
| inputs: Dict[str, Any] | |||
| inputs: Dict[str, Any], | |||
| run_manager: Optional[CallbackManagerForChainRun] = None, | |||
| ) -> Dict[str, Any]: | |||
| if len(self.dataset_tools) == 0: | |||
| return {"text": ''} | |||
| @@ -1,5 +1,6 @@ | |||
| from typing import List, Dict | |||
| from typing import List, Dict, Optional, Any | |||
| from langchain.callbacks.manager import CallbackManagerForChainRun | |||
| from langchain.chains.base import Chain | |||
| @@ -36,7 +37,11 @@ class SensitiveWordAvoidanceChain(Chain): | |||
| return self.canned_response | |||
| return text | |||
| def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: | |||
| def _call( | |||
| self, | |||
| inputs: Dict[str, Any], | |||
| run_manager: Optional[CallbackManagerForChainRun] = None, | |||
| ) -> Dict[str, Any]: | |||
| text = inputs[self.input_key] | |||
| output = self._check_sensitive_word(text) | |||
| return {self.output_key: output} | |||
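A generic, self-contained sketch of the updated Chain signature used in these hunks, where `_call` now accepts an optional run manager (the EchoChain itself is hypothetical, not part of the diff):

from typing import Any, Dict, List, Optional
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain

class EchoChain(Chain):
    input_key: str = 'input'
    output_key: str = 'output'

    @property
    def input_keys(self) -> List[str]:
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        return [self.output_key]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        # echo the input through unchanged
        return {self.output_key: inputs[self.input_key]}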
| @@ -1,5 +1,6 @@ | |||
| from typing import List, Dict | |||
| from typing import List, Dict, Optional, Any | |||
| from langchain.callbacks.manager import CallbackManagerForChainRun, AsyncCallbackManagerForChainRun | |||
| from langchain.chains.base import Chain | |||
| from langchain.tools import BaseTool | |||
| @@ -30,12 +31,20 @@ class ToolChain(Chain): | |||
| """ | |||
| return [self.output_key] | |||
| def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: | |||
| def _call( | |||
| self, | |||
| inputs: Dict[str, Any], | |||
| run_manager: Optional[CallbackManagerForChainRun] = None, | |||
| ) -> Dict[str, Any]: | |||
| input = inputs[self.input_key] | |||
| output = self.tool.run(input, self.verbose) | |||
| return {self.output_key: output} | |||
| async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]: | |||
| async def _acall( | |||
| self, | |||
| inputs: Dict[str, Any], | |||
| run_manager: Optional[AsyncCallbackManagerForChainRun] = None, | |||
| ) -> Dict[str, Any]: | |||
| """Run the logic of this chain and return the output.""" | |||
| input = inputs[self.input_key] | |||
| output = await self.tool.arun(input, self.verbose) | |||
| @@ -1,17 +1,18 @@ | |||
| import logging | |||
| from typing import Optional, List, Union, Tuple | |||
| from langchain.callbacks import CallbackManager | |||
| from langchain.base_language import BaseLanguageModel | |||
| from langchain.callbacks.base import BaseCallbackHandler | |||
| from langchain.chat_models.base import BaseChatModel | |||
| from langchain.llms import BaseLLM | |||
| from langchain.schema import BaseMessage, BaseLanguageModel, HumanMessage | |||
| from langchain.schema import BaseMessage, HumanMessage | |||
| from requests.exceptions import ChunkedEncodingError | |||
| from core.constant import llm_constant | |||
| from core.callback_handler.llm_callback_handler import LLMCallbackHandler | |||
| from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, \ | |||
| DifyStdOutCallbackHandler | |||
| from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, PubHandler | |||
| from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException | |||
| from core.llm.error import LLMBadRequestError | |||
| from core.llm.llm_builder import LLMBuilder | |||
| from core.chain.main_chain_builder import MainChainBuilder | |||
| @@ -34,8 +35,6 @@ class Completion: | |||
| """ | |||
| errors: ProviderTokenNotInitError | |||
| """ | |||
| cls.validate_query_tokens(app.tenant_id, app_model_config, query) | |||
| memory = None | |||
| if conversation: | |||
| # get memory of conversation (read-only) | |||
| @@ -48,6 +47,14 @@ class Completion: | |||
| inputs = conversation.inputs | |||
| rest_tokens_for_context_and_memory = cls.get_validate_rest_tokens( | |||
| mode=app.mode, | |||
| tenant_id=app.tenant_id, | |||
| app_model_config=app_model_config, | |||
| query=query, | |||
| inputs=inputs | |||
| ) | |||
| conversation_message_task = ConversationMessageTask( | |||
| task_id=task_id, | |||
| app=app, | |||
| @@ -64,6 +71,7 @@ class Completion: | |||
| main_chain = MainChainBuilder.to_langchain_components( | |||
| tenant_id=app.tenant_id, | |||
| agent_mode=app_model_config.agent_mode_dict, | |||
| rest_tokens=rest_tokens_for_context_and_memory, | |||
| memory=ReadOnlyConversationTokenDBStringBufferSharedMemory(memory=memory) if memory else None, | |||
| conversation_message_task=conversation_message_task | |||
| ) | |||
| @@ -115,7 +123,7 @@ class Completion: | |||
| memory=memory | |||
| ) | |||
| final_llm.callback_manager = cls.get_llm_callback_manager(final_llm, streaming, conversation_message_task) | |||
| final_llm.callbacks = cls.get_llm_callbacks(final_llm, streaming, conversation_message_task) | |||
| cls.recale_llm_max_tokens( | |||
| final_llm=final_llm, | |||
| @@ -247,16 +255,14 @@ And answer according to the language of the user's question. | |||
| return messages, ['\nHuman:'] | |||
| @classmethod | |||
| def get_llm_callback_manager(cls, llm: Union[StreamableOpenAI, StreamableChatOpenAI], | |||
| streaming: bool, | |||
| conversation_message_task: ConversationMessageTask) -> CallbackManager: | |||
| def get_llm_callbacks(cls, llm: Union[StreamableOpenAI, StreamableChatOpenAI], | |||
| streaming: bool, | |||
| conversation_message_task: ConversationMessageTask) -> List[BaseCallbackHandler]: | |||
| llm_callback_handler = LLMCallbackHandler(llm, conversation_message_task) | |||
| if streaming: | |||
| callback_handlers = [llm_callback_handler, DifyStreamingStdOutCallbackHandler()] | |||
| return [llm_callback_handler, DifyStreamingStdOutCallbackHandler()] | |||
| else: | |||
| callback_handlers = [llm_callback_handler, DifyStdOutCallbackHandler()] | |||
| return CallbackManager(callback_handlers) | |||
| return [llm_callback_handler, DifyStdOutCallbackHandler()] | |||
| @classmethod | |||
| def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory, | |||
| @@ -293,7 +299,8 @@ And answer according to the language of the user's question. | |||
| return memory | |||
| @classmethod | |||
| def validate_query_tokens(cls, tenant_id: str, app_model_config: AppModelConfig, query: str): | |||
| def get_validate_rest_tokens(cls, mode: str, tenant_id: str, app_model_config: AppModelConfig, | |||
| query: str, inputs: dict) -> int: | |||
| llm = LLMBuilder.to_llm_from_model( | |||
| tenant_id=tenant_id, | |||
| model=app_model_config.model_dict | |||
| @@ -302,8 +309,26 @@ And answer according to the language of the user's question. | |||
| model_limited_tokens = llm_constant.max_context_token_length[llm.model_name] | |||
| max_tokens = llm.max_tokens | |||
| if model_limited_tokens - max_tokens - llm.get_num_tokens(query) < 0: | |||
| raise LLMBadRequestError("Query is too long") | |||
| # get prompt without memory and context | |||
| prompt, _ = cls.get_main_llm_prompt( | |||
| mode=mode, | |||
| llm=llm, | |||
| pre_prompt=app_model_config.pre_prompt, | |||
| query=query, | |||
| inputs=inputs, | |||
| chain_output=None, | |||
| memory=None | |||
| ) | |||
| prompt_tokens = llm.get_num_tokens(prompt) if isinstance(prompt, str) \ | |||
| else llm.get_num_tokens_from_messages(prompt) | |||
| rest_tokens = model_limited_tokens - max_tokens - prompt_tokens | |||
| if rest_tokens < 0: | |||
| raise LLMBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, " | |||
| "or shrink the max token, or switch to a llm with a larger token limit size.") | |||
| return rest_tokens | |||
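A small numeric illustration of the budget computed above (all numbers are assumptions):

model_limited_tokens = 4096   # context window of the configured model
max_tokens = 512              # completion budget of the app
prompt_tokens = 900           # prompt without memory and context
rest_tokens = model_limited_tokens - max_tokens - prompt_tokens   # 2684 left for context and memory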
| @classmethod | |||
| def recale_llm_max_tokens(cls, final_llm: Union[StreamableOpenAI, StreamableChatOpenAI], | |||
| @@ -360,7 +385,7 @@ And answer according to the language of the user's question. | |||
| streaming=streaming | |||
| ) | |||
| llm.callback_manager = cls.get_llm_callback_manager(llm, streaming, conversation_message_task) | |||
| llm.callbacks = cls.get_llm_callbacks(llm, streaming, conversation_message_task) | |||
| cls.recale_llm_max_tokens( | |||
| final_llm=llm, | |||
| @@ -293,12 +293,12 @@ class PubHandler: | |||
| if not user: | |||
| raise ValueError("user is required") | |||
| user_str = 'account-' + user.id if isinstance(user, Account) else 'end-user-' + user.id | |||
| user_str = 'account-' + str(user.id) if isinstance(user, Account) else 'end-user-' + str(user.id) | |||
| return "generate_result:{}-{}".format(user_str, task_id) | |||
| @classmethod | |||
| def generate_stopped_cache_key(cls, user: Union[Account | EndUser], task_id: str): | |||
| user_str = 'account-' + user.id if isinstance(user, Account) else 'end-user-' + user.id | |||
| user_str = 'account-' + str(user.id) if isinstance(user, Account) else 'end-user-' + str(user.id) | |||
| return "generate_result_stopped:{}-{}".format(user_str, task_id) | |||
| def pub_text(self, text: str): | |||
| @@ -306,10 +306,10 @@ class PubHandler: | |||
| 'event': 'message', | |||
| 'data': { | |||
| 'task_id': self._task_id, | |||
| 'message_id': self._message.id, | |||
| 'message_id': str(self._message.id), | |||
| 'text': text, | |||
| 'mode': self._conversation.mode, | |||
| 'conversation_id': self._conversation.id | |||
| 'conversation_id': str(self._conversation.id) | |||
| } | |||
| } | |||
| @@ -0,0 +1,43 @@ | |||
| import tempfile | |||
| from pathlib import Path | |||
| from typing import List, Union | |||
| from langchain.document_loaders import TextLoader, Docx2txtLoader | |||
| from langchain.schema import Document | |||
| from core.data_loader.loader.csv import CSVLoader | |||
| from core.data_loader.loader.excel import ExcelLoader | |||
| from core.data_loader.loader.html import HTMLLoader | |||
| from core.data_loader.loader.markdown import MarkdownLoader | |||
| from core.data_loader.loader.pdf import PdfLoader | |||
| from extensions.ext_storage import storage | |||
| from models.model import UploadFile | |||
| class FileExtractor: | |||
| @classmethod | |||
def load(cls, upload_file: UploadFile, return_text: bool = False) -> Union[List[Document], str]:
| with tempfile.TemporaryDirectory() as temp_dir: | |||
| suffix = Path(upload_file.key).suffix | |||
| file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" | |||
| storage.download(upload_file.key, file_path) | |||
| input_file = Path(file_path) | |||
| delimiter = '\n' | |||
| if input_file.suffix == '.xlsx': | |||
| loader = ExcelLoader(file_path) | |||
| elif input_file.suffix == '.pdf': | |||
| loader = PdfLoader(file_path, upload_file=upload_file) | |||
| elif input_file.suffix in ['.md', '.markdown']: | |||
| loader = MarkdownLoader(file_path, autodetect_encoding=True) | |||
| elif input_file.suffix in ['.htm', '.html']: | |||
| loader = HTMLLoader(file_path) | |||
| elif input_file.suffix == '.docx': | |||
| loader = Docx2txtLoader(file_path) | |||
| elif input_file.suffix == '.csv': | |||
| loader = CSVLoader(file_path, autodetect_encoding=True) | |||
| else: | |||
| # txt | |||
| loader = TextLoader(file_path, autodetect_encoding=True) | |||
| return delimiter.join([document.page_content for document in loader.load()]) if return_text else loader.load() | |||
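The return_text branch simply joins page contents with a newline; a tiny standalone illustration, independent of storage and the database:

from langchain.schema import Document

docs = [Document(page_content='first page'), Document(page_content='second page')]
text = '\n'.join(document.page_content for document in docs)
print(text)  # "first page\nsecond page"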
| @@ -0,0 +1,67 @@ | |||
import csv
import logging
| from typing import Optional, Dict, List | |||
| from langchain.document_loaders import CSVLoader as LCCSVLoader | |||
| from langchain.document_loaders.helpers import detect_file_encodings | |||
from langchain.schema import Document
| logger = logging.getLogger(__name__) | |||
| class CSVLoader(LCCSVLoader): | |||
| def __init__( | |||
| self, | |||
| file_path: str, | |||
| source_column: Optional[str] = None, | |||
| csv_args: Optional[Dict] = None, | |||
| encoding: Optional[str] = None, | |||
| autodetect_encoding: bool = True, | |||
| ): | |||
| self.file_path = file_path | |||
| self.source_column = source_column | |||
| self.encoding = encoding | |||
| self.csv_args = csv_args or {} | |||
| self.autodetect_encoding = autodetect_encoding | |||
| def load(self) -> List[Document]: | |||
| """Load data into document objects.""" | |||
| try: | |||
| with open(self.file_path, newline="", encoding=self.encoding) as csvfile: | |||
| docs = self._read_from_file(csvfile) | |||
| except UnicodeDecodeError as e: | |||
| if self.autodetect_encoding: | |||
| detected_encodings = detect_file_encodings(self.file_path) | |||
| for encoding in detected_encodings: | |||
| logger.debug("Trying encoding: ", encoding.encoding) | |||
| try: | |||
| with open(self.file_path, newline="", encoding=encoding.encoding) as csvfile: | |||
| docs = self._read_from_file(csvfile) | |||
| break | |||
| except UnicodeDecodeError: | |||
| continue | |||
| else: | |||
| raise RuntimeError(f"Error loading {self.file_path}") from e | |||
| return docs | |||
| def _read_from_file(self, csvfile): | |||
| docs = [] | |||
| csv_reader = csv.DictReader(csvfile, **self.csv_args) # type: ignore | |||
| for i, row in enumerate(csv_reader): | |||
| content = "\n".join(f"{k.strip()}: {v.strip()}" for k, v in row.items()) | |||
| try: | |||
| source = ( | |||
| row[self.source_column] | |||
| if self.source_column is not None | |||
| else '' | |||
| ) | |||
| except KeyError: | |||
| raise ValueError( | |||
| f"Source column '{self.source_column}' not found in CSV file." | |||
| ) | |||
| metadata = {"source": source, "row": i} | |||
| doc = Document(page_content=content, metadata=metadata) | |||
| docs.append(doc) | |||
| return docs | |||
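A hedged usage sketch for this loader (the CSV content is illustrative, and it assumes the loader returns langchain-style Document objects with `csv` imported):

import os
import tempfile

csv_path = os.path.join(tempfile.mkdtemp(), 'example.csv')
with open(csv_path, 'w', newline='', encoding='utf-8') as f:
    f.write('question,answer\nWhat is Dify?,An LLM app platform\n')

for doc in CSVLoader(csv_path).load():
    # each row becomes one Document whose content is "header: value" lines
    print(doc.metadata['row'], doc.page_content)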
| @@ -0,0 +1,43 @@ | |||
| import json | |||
| import logging | |||
| from typing import List | |||
| from langchain.document_loaders.base import BaseLoader | |||
| from langchain.schema import Document | |||
| from openpyxl.reader.excel import load_workbook | |||
| logger = logging.getLogger(__name__) | |||
| class ExcelLoader(BaseLoader): | |||
| """Load xlxs files. | |||
| Args: | |||
| file_path: Path to the file to load. | |||
| """ | |||
| def __init__( | |||
| self, | |||
| file_path: str | |||
| ): | |||
| """Initialize with file path.""" | |||
| self._file_path = file_path | |||
| def load(self) -> List[Document]: | |||
| data = [] | |||
| keys = [] | |||
| wb = load_workbook(filename=self._file_path, read_only=True) | |||
| # loop over all sheets | |||
| for sheet in wb: | |||
| for row in sheet.iter_rows(values_only=True): | |||
| if all(v is None for v in row): | |||
| continue | |||
| if keys == []: | |||
| keys = list(map(str, row)) | |||
| else: | |||
| row_dict = dict(zip(keys, row)) | |||
| row_dict = {k: v for k, v in row_dict.items() if v} | |||
| data.append(json.dumps(row_dict, ensure_ascii=False)) | |||
| return [Document(page_content='\n\n'.join(data))] | |||
| @@ -0,0 +1,35 @@ | |||
| import logging | |||
| from typing import List | |||
| from bs4 import BeautifulSoup | |||
| from langchain.document_loaders.base import BaseLoader | |||
| from langchain.schema import Document | |||
| logger = logging.getLogger(__name__) | |||
| class HTMLLoader(BaseLoader): | |||
| """Load html files. | |||
| Args: | |||
| file_path: Path to the file to load. | |||
| """ | |||
| def __init__( | |||
| self, | |||
| file_path: str | |||
| ): | |||
| """Initialize with file path.""" | |||
| self._file_path = file_path | |||
| def load(self) -> List[Document]: | |||
| return [Document(page_content=self._load_as_text())] | |||
| def _load_as_text(self) -> str: | |||
| with open(self._file_path, "rb") as fp: | |||
| soup = BeautifulSoup(fp, 'html.parser') | |||
| text = soup.get_text() | |||
| text = text.strip() if text else '' | |||
| return text | |||
| @@ -0,0 +1,134 @@ | |||
| import logging | |||
| import re | |||
| from typing import Optional, List, Tuple, cast | |||
| from langchain.document_loaders.base import BaseLoader | |||
| from langchain.document_loaders.helpers import detect_file_encodings | |||
| from langchain.schema import Document | |||
| logger = logging.getLogger(__name__) | |||
| class MarkdownLoader(BaseLoader): | |||
| """Load md files. | |||
| Args: | |||
| file_path: Path to the file to load. | |||
| remove_hyperlinks: Whether to remove hyperlinks from the text. | |||
| remove_images: Whether to remove images from the text. | |||
| encoding: File encoding to use. If `None`, the file will be loaded | |||
| with the default system encoding. | |||
| autodetect_encoding: Whether to try to autodetect the file encoding | |||
| if the specified encoding fails. | |||
| """ | |||
| def __init__( | |||
| self, | |||
| file_path: str, | |||
| remove_hyperlinks: bool = True, | |||
| remove_images: bool = True, | |||
| encoding: Optional[str] = None, | |||
| autodetect_encoding: bool = True, | |||
| ): | |||
| """Initialize with file path.""" | |||
| self._file_path = file_path | |||
| self._remove_hyperlinks = remove_hyperlinks | |||
| self._remove_images = remove_images | |||
| self._encoding = encoding | |||
| self._autodetect_encoding = autodetect_encoding | |||
| def load(self) -> List[Document]: | |||
| tups = self.parse_tups(self._file_path) | |||
| documents = [] | |||
| for header, value in tups: | |||
| value = value.strip() | |||
| if header is None: | |||
| documents.append(Document(page_content=value)) | |||
| else: | |||
| documents.append(Document(page_content=f"\n\n{header}\n{value}")) | |||
| return documents | |||
| def markdown_to_tups(self, markdown_text: str) -> List[Tuple[Optional[str], str]]: | |||
| """Convert a markdown file to a dictionary. | |||
| The keys are the headers and the values are the text under each header. | |||
| """ | |||
| markdown_tups: List[Tuple[Optional[str], str]] = [] | |||
| lines = markdown_text.split("\n") | |||
| current_header = None | |||
| current_text = "" | |||
| for line in lines: | |||
| header_match = re.match(r"^#+\s", line) | |||
| if header_match: | |||
| if current_header is not None: | |||
| markdown_tups.append((current_header, current_text)) | |||
| current_header = line | |||
| current_text = "" | |||
| else: | |||
| current_text += line + "\n" | |||
| markdown_tups.append((current_header, current_text)) | |||
| if current_header is not None: | |||
| # at least one header was found: strip '#' marks from headers and inline HTML tags from values | |||
| markdown_tups = [ | |||
| (re.sub(r"#", "", cast(str, key)).strip(), re.sub(r"<.*?>", "", value)) | |||
| for key, value in markdown_tups | |||
| ] | |||
| else: | |||
| markdown_tups = [ | |||
| (key, re.sub("\n", "", value)) for key, value in markdown_tups | |||
| ] | |||
| return markdown_tups | |||
| def remove_images(self, content: str) -> str: | |||
| """Get a dictionary of a markdown file from its path.""" | |||
| pattern = r"!{1}\[\[(.*)\]\]" | |||
| content = re.sub(pattern, "", content) | |||
| return content | |||
| def remove_hyperlinks(self, content: str) -> str: | |||
| """Get a dictionary of a markdown file from its path.""" | |||
| pattern = r"\[(.*?)\]\((.*?)\)" | |||
| content = re.sub(pattern, r"\1", content) | |||
| return content | |||
| def parse_tups(self, filepath: str) -> List[Tuple[Optional[str], str]]: | |||
| """Parse file into tuples.""" | |||
| content = "" | |||
| try: | |||
| with open(filepath, "r", encoding=self._encoding) as f: | |||
| content = f.read() | |||
| except UnicodeDecodeError as e: | |||
| if self._autodetect_encoding: | |||
| detected_encodings = detect_file_encodings(filepath) | |||
| for encoding in detected_encodings: | |||
| logger.debug("Trying encoding: ", encoding.encoding) | |||
| try: | |||
| with open(filepath, encoding=encoding.encoding) as f: | |||
| content = f.read() | |||
| break | |||
| except UnicodeDecodeError: | |||
| continue | |||
| else: | |||
| raise RuntimeError(f"Error loading {filepath}") from e | |||
| except Exception as e: | |||
| raise RuntimeError(f"Error loading {filepath}") from e | |||
| if self._remove_hyperlinks: | |||
| content = self.remove_hyperlinks(content) | |||
| if self._remove_images: | |||
| content = self.remove_images(content) | |||
| return self.markdown_to_tups(content) | |||
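A hedged usage sketch of MarkdownLoader with its defaults (import path and file are placeholders). One detail worth noting from parse_tups/markdown_to_tups: text that appears before the first header is not carried into any tuple, so it does not surface in the loaded Documents.

```python
# Hypothetical import path and file for illustration only.
from core.data_loader.loader.markdown import MarkdownLoader

loader = MarkdownLoader(
    file_path="/tmp/guide.md",
    remove_hyperlinks=True,   # '[text](url)' becomes 'text'
    remove_images=True,       # '![[image]]' embeds are dropped
)

# One Document per header section, with the header kept at the top of the content.
for doc in loader.load():
    print(doc.page_content[:80])
```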
| @@ -1,68 +1,162 @@ | |||
| """Notion reader.""" | |||
| import json | |||
| import logging | |||
| import os | |||
| from datetime import datetime | |||
| from typing import Any, Dict, List, Optional | |||
| from typing import List, Dict, Any, Optional | |||
| import requests # type: ignore | |||
| import requests | |||
| from flask import current_app | |||
| from langchain.document_loaders.base import BaseLoader | |||
| from langchain.schema import Document | |||
| from llama_index.readers.base import BaseReader | |||
| from llama_index.readers.schema.base import Document | |||
| from extensions.ext_database import db | |||
| from models.dataset import Document as DocumentModel | |||
| from models.source import DataSourceBinding | |||
| logger = logging.getLogger(__name__) | |||
| INTEGRATION_TOKEN_NAME = "NOTION_INTEGRATION_TOKEN" | |||
| BLOCK_CHILD_URL_TMPL = "https://api.notion.com/v1/blocks/{block_id}/children" | |||
| DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}/query" | |||
| SEARCH_URL = "https://api.notion.com/v1/search" | |||
| RETRIEVE_PAGE_URL_TMPL = "https://api.notion.com/v1/pages/{page_id}" | |||
| RETRIEVE_DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}" | |||
| HEADING_TYPE = ['heading_1', 'heading_2', 'heading_3'] | |||
| logger = logging.getLogger(__name__) | |||
| # TODO: Notion DB reader coming soon! | |||
| class NotionPageReader(BaseReader): | |||
| """Notion Page reader. | |||
| Reads a set of Notion pages. | |||
| Args: | |||
| integration_token (str): Notion integration token. | |||
| """ | |||
| def __init__(self, integration_token: Optional[str] = None) -> None: | |||
| """Initialize with parameters.""" | |||
| if integration_token is None: | |||
| integration_token = os.getenv(INTEGRATION_TOKEN_NAME) | |||
| class NotionLoader(BaseLoader): | |||
| def __init__( | |||
| self, | |||
| notion_access_token: str, | |||
| notion_workspace_id: str, | |||
| notion_obj_id: str, | |||
| notion_page_type: str, | |||
| document_model: Optional[DocumentModel] = None | |||
| ): | |||
| self._document_model = document_model | |||
| self._notion_workspace_id = notion_workspace_id | |||
| self._notion_obj_id = notion_obj_id | |||
| self._notion_page_type = notion_page_type | |||
| self._notion_access_token = notion_access_token | |||
| if not self._notion_access_token: | |||
| integration_token = current_app.config.get('NOTION_INTEGRATION_TOKEN') | |||
| if integration_token is None: | |||
| raise ValueError( | |||
| "Must specify `integration_token` or set environment " | |||
| "variable `NOTION_INTEGRATION_TOKEN`." | |||
| ) | |||
| self.token = integration_token | |||
| self.headers = { | |||
| "Authorization": "Bearer " + self.token, | |||
| "Content-Type": "application/json", | |||
| "Notion-Version": "2022-06-28", | |||
| } | |||
| def _read_block(self, block_id: str, num_tabs: int = 0) -> str: | |||
| """Read a block.""" | |||
| done = False | |||
| self._notion_access_token = integration_token | |||
| @classmethod | |||
| def from_document(cls, document_model: DocumentModel): | |||
| data_source_info = document_model.data_source_info_dict | |||
| if not data_source_info or 'notion_page_id' not in data_source_info \ | |||
| or 'notion_workspace_id' not in data_source_info: | |||
| raise ValueError("no notion page found") | |||
| notion_workspace_id = data_source_info['notion_workspace_id'] | |||
| notion_obj_id = data_source_info['notion_page_id'] | |||
| notion_page_type = data_source_info['type'] | |||
| notion_access_token = cls._get_access_token(document_model.tenant_id, notion_workspace_id) | |||
| return cls( | |||
| notion_access_token=notion_access_token, | |||
| notion_workspace_id=notion_workspace_id, | |||
| notion_obj_id=notion_obj_id, | |||
| notion_page_type=notion_page_type, | |||
| document_model=document_model | |||
| ) | |||
| def load(self) -> List[Document]: | |||
| self.update_last_edited_time( | |||
| self._document_model | |||
| ) | |||
| text_docs = self._load_data_as_documents(self._notion_obj_id, self._notion_page_type) | |||
| return text_docs | |||
| def _load_data_as_documents( | |||
| self, notion_obj_id: str, notion_page_type: str | |||
| ) -> List[Document]: | |||
| docs = [] | |||
| if notion_page_type == 'database': | |||
| # get all the pages in the database | |||
| page_text = self._get_notion_database_data(notion_obj_id) | |||
| docs.append(Document(page_content=page_text)) | |||
| elif notion_page_type == 'page': | |||
| page_text_list = self._get_notion_block_data(notion_obj_id) | |||
| for page_text in page_text_list: | |||
| docs.append(Document(page_content=page_text)) | |||
| else: | |||
| raise ValueError("notion page type not supported") | |||
| return docs | |||
| def _get_notion_database_data( | |||
| self, database_id: str, query_dict: Dict[str, Any] = {} | |||
| ) -> str: | |||
| """Get all the pages from a Notion database.""" | |||
| res = requests.post( | |||
| DATABASE_URL_TMPL.format(database_id=database_id), | |||
| headers={ | |||
| "Authorization": "Bearer " + self._notion_access_token, | |||
| "Content-Type": "application/json", | |||
| "Notion-Version": "2022-06-28", | |||
| }, | |||
| json=query_dict, | |||
| ) | |||
| data = res.json() | |||
| database_content_list = [] | |||
| if 'results' not in data or data["results"] is None: | |||
| return "" | |||
| for result in data["results"]: | |||
| properties = result['properties'] | |||
| data = {} | |||
| for property_name, property_value in properties.items(): | |||
| type = property_value['type'] | |||
| if type == 'multi_select': | |||
| value = [] | |||
| multi_select_list = property_value[type] | |||
| for multi_select in multi_select_list: | |||
| value.append(multi_select['name']) | |||
| elif type == 'rich_text' or type == 'title': | |||
| if len(property_value[type]) > 0: | |||
| value = property_value[type][0]['plain_text'] | |||
| else: | |||
| value = '' | |||
| elif type == 'select' or type == 'status': | |||
| if property_value[type]: | |||
| value = property_value[type]['name'] | |||
| else: | |||
| value = '' | |||
| else: | |||
| value = property_value[type] | |||
| data[property_name] = value | |||
| database_content_list.append(json.dumps(data, ensure_ascii=False)) | |||
| return "\n\n".join(database_content_list) | |||
| def _get_notion_block_data(self, page_id: str) -> List[str]: | |||
| result_lines_arr = [] | |||
| cur_block_id = block_id | |||
| while not done: | |||
| cur_block_id = page_id | |||
| while True: | |||
| block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id) | |||
| query_dict: Dict[str, Any] = {} | |||
| res = requests.request( | |||
| "GET", block_url, headers=self.headers, json=query_dict | |||
| "GET", | |||
| block_url, | |||
| headers={ | |||
| "Authorization": "Bearer " + self._notion_access_token, | |||
| "Content-Type": "application/json", | |||
| "Notion-Version": "2022-06-28", | |||
| }, | |||
| json=query_dict | |||
| ) | |||
| data = res.json() | |||
| if 'results' not in data or data["results"] is None: | |||
| done = True | |||
| break | |||
| # current block's heading | |||
| heading = '' | |||
| for result in data["results"]: | |||
| result_type = result["type"] | |||
| @@ -71,6 +165,7 @@ class NotionPageReader(BaseReader): | |||
| if result_type == 'table': | |||
| result_block_id = result["id"] | |||
| text = self._read_table_rows(result_block_id) | |||
| text += "\n\n" | |||
| result_lines_arr.append(text) | |||
| else: | |||
| if "rich_text" in result_obj: | |||
| @@ -78,91 +173,53 @@ class NotionPageReader(BaseReader): | |||
| # skip if doesn't have text object | |||
| if "text" in rich_text: | |||
| text = rich_text["text"]["content"] | |||
| prefix = "\t" * num_tabs | |||
| cur_result_text_arr.append(prefix + text) | |||
| cur_result_text_arr.append(text) | |||
| if result_type in HEADING_TYPE: | |||
| heading = text | |||
| result_block_id = result["id"] | |||
| has_children = result["has_children"] | |||
| block_type = result["type"] | |||
| if has_children and block_type != 'child_page': | |||
| children_text = self._read_block( | |||
| result_block_id, num_tabs=num_tabs + 1 | |||
| result_block_id, num_tabs=1 | |||
| ) | |||
| cur_result_text_arr.append(children_text) | |||
| cur_result_text = "\n".join(cur_result_text_arr) | |||
| cur_result_text += "\n\n" | |||
| if result_type in HEADING_TYPE: | |||
| result_lines_arr.append(cur_result_text) | |||
| else: | |||
| result_lines_arr.append(f'{heading}\n{cur_result_text}') | |||
| if data["next_cursor"] is None: | |||
| done = True | |||
| break | |||
| else: | |||
| cur_block_id = data["next_cursor"] | |||
| result_lines = "\n".join(result_lines_arr) | |||
| return result_lines | |||
| def _read_table_rows(self, block_id: str) -> str: | |||
| """Read table rows.""" | |||
| done = False | |||
| result_lines_arr = [] | |||
| cur_block_id = block_id | |||
| while not done: | |||
| block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id) | |||
| query_dict: Dict[str, Any] = {} | |||
| res = requests.request( | |||
| "GET", block_url, headers=self.headers, json=query_dict | |||
| ) | |||
| data = res.json() | |||
| # get table headers text | |||
| table_header_cell_texts = [] | |||
| tabel_header_cells = data["results"][0]['table_row']['cells'] | |||
| for tabel_header_cell in tabel_header_cells: | |||
| if tabel_header_cell: | |||
| for table_header_cell_text in tabel_header_cell: | |||
| text = table_header_cell_text["text"]["content"] | |||
| table_header_cell_texts.append(text) | |||
| # get table columns text and format | |||
| results = data["results"] | |||
| for i in range(len(results)-1): | |||
| column_texts = [] | |||
| tabel_column_cells = data["results"][i+1]['table_row']['cells'] | |||
| for j in range(len(tabel_column_cells)): | |||
| if tabel_column_cells[j]: | |||
| for table_column_cell_text in tabel_column_cells[j]: | |||
| column_text = table_column_cell_text["text"]["content"] | |||
| column_texts.append(f'{table_header_cell_texts[j]}:{column_text}') | |||
| cur_result_text = "\n".join(column_texts) | |||
| result_lines_arr.append(cur_result_text) | |||
| if data["next_cursor"] is None: | |||
| done = True | |||
| break | |||
| else: | |||
| cur_block_id = data["next_cursor"] | |||
| return result_lines_arr | |||
| result_lines = "\n".join(result_lines_arr) | |||
| return result_lines | |||
| def _read_parent_blocks(self, block_id: str, num_tabs: int = 0) -> List[str]: | |||
| def _read_block(self, block_id: str, num_tabs: int = 0) -> str: | |||
| """Read a block.""" | |||
| done = False | |||
| result_lines_arr = [] | |||
| cur_block_id = block_id | |||
| while not done: | |||
| while True: | |||
| block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id) | |||
| query_dict: Dict[str, Any] = {} | |||
| res = requests.request( | |||
| "GET", block_url, headers=self.headers, json=query_dict | |||
| "GET", | |||
| block_url, | |||
| headers={ | |||
| "Authorization": "Bearer " + self._notion_access_token, | |||
| "Content-Type": "application/json", | |||
| "Notion-Version": "2022-06-28", | |||
| }, | |||
| json=query_dict | |||
| ) | |||
| data = res.json() | |||
| # current block's heading | |||
| if 'results' not in data or data["results"] is None: | |||
| break | |||
| heading = '' | |||
| for result in data["results"]: | |||
| result_type = result["type"] | |||
| @@ -171,7 +228,6 @@ class NotionPageReader(BaseReader): | |||
| if result_type == 'table': | |||
| result_block_id = result["id"] | |||
| text = self._read_table_rows(result_block_id) | |||
| text += "\n\n" | |||
| result_lines_arr.append(text) | |||
| else: | |||
| if "rich_text" in result_obj: | |||
| @@ -179,10 +235,10 @@ class NotionPageReader(BaseReader): | |||
| # skip if doesn't have text object | |||
| if "text" in rich_text: | |||
| text = rich_text["text"]["content"] | |||
| cur_result_text_arr.append(text) | |||
| prefix = "\t" * num_tabs | |||
| cur_result_text_arr.append(prefix + text) | |||
| if result_type in HEADING_TYPE: | |||
| heading = text | |||
| result_block_id = result["id"] | |||
| has_children = result["has_children"] | |||
| block_type = result["type"] | |||
| @@ -193,177 +249,121 @@ class NotionPageReader(BaseReader): | |||
| cur_result_text_arr.append(children_text) | |||
| cur_result_text = "\n".join(cur_result_text_arr) | |||
| cur_result_text += "\n\n" | |||
| if result_type in HEADING_TYPE: | |||
| result_lines_arr.append(cur_result_text) | |||
| else: | |||
| result_lines_arr.append(f'{heading}\n{cur_result_text}') | |||
| if data["next_cursor"] is None: | |||
| done = True | |||
| break | |||
| else: | |||
| cur_block_id = data["next_cursor"] | |||
| return result_lines_arr | |||
| def read_page(self, page_id: str) -> str: | |||
| """Read a page.""" | |||
| return self._read_block(page_id) | |||
| def read_page_as_documents(self, page_id: str) -> List[str]: | |||
| """Read a page as documents.""" | |||
| return self._read_parent_blocks(page_id) | |||
| def query_database_data( | |||
| self, database_id: str, query_dict: Dict[str, Any] = {} | |||
| ) -> str: | |||
| """Get all the pages from a Notion database.""" | |||
| res = requests.post\ | |||
| ( | |||
| DATABASE_URL_TMPL.format(database_id=database_id), | |||
| headers=self.headers, | |||
| json=query_dict, | |||
| ) | |||
| data = res.json() | |||
| database_content_list = [] | |||
| if 'results' not in data or data["results"] is None: | |||
| return "" | |||
| for result in data["results"]: | |||
| properties = result['properties'] | |||
| data = {} | |||
| for property_name, property_value in properties.items(): | |||
| type = property_value['type'] | |||
| if type == 'multi_select': | |||
| value = [] | |||
| multi_select_list = property_value[type] | |||
| for multi_select in multi_select_list: | |||
| value.append(multi_select['name']) | |||
| elif type == 'rich_text' or type == 'title': | |||
| if len(property_value[type]) > 0: | |||
| value = property_value[type][0]['plain_text'] | |||
| else: | |||
| value = '' | |||
| elif type == 'select' or type == 'status': | |||
| if property_value[type]: | |||
| value = property_value[type]['name'] | |||
| else: | |||
| value = '' | |||
| else: | |||
| value = property_value[type] | |||
| data[property_name] = value | |||
| database_content_list.append(json.dumps(data, ensure_ascii=False)) | |||
| return "\n\n".join(database_content_list) | |||
| def query_database( | |||
| self, database_id: str, query_dict: Dict[str, Any] = {} | |||
| ) -> List[str]: | |||
| """Get all the pages from a Notion database.""" | |||
| res = requests.post\ | |||
| ( | |||
| DATABASE_URL_TMPL.format(database_id=database_id), | |||
| headers=self.headers, | |||
| json=query_dict, | |||
| ) | |||
| data = res.json() | |||
| page_ids = [] | |||
| for result in data["results"]: | |||
| page_id = result["id"] | |||
| page_ids.append(page_id) | |||
| return page_ids | |||
| result_lines = "\n".join(result_lines_arr) | |||
| return result_lines | |||
| def search(self, query: str) -> List[str]: | |||
| """Search Notion page given a text query.""" | |||
| def _read_table_rows(self, block_id: str) -> str: | |||
| """Read table rows.""" | |||
| done = False | |||
| next_cursor: Optional[str] = None | |||
| page_ids = [] | |||
| result_lines_arr = [] | |||
| cur_block_id = block_id | |||
| while not done: | |||
| query_dict = { | |||
| "query": query, | |||
| } | |||
| if next_cursor is not None: | |||
| query_dict["start_cursor"] = next_cursor | |||
| res = requests.post(SEARCH_URL, headers=self.headers, json=query_dict) | |||
| block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id) | |||
| query_dict: Dict[str, Any] = {} | |||
| res = requests.request( | |||
| "GET", | |||
| block_url, | |||
| headers={ | |||
| "Authorization": "Bearer " + self._notion_access_token, | |||
| "Content-Type": "application/json", | |||
| "Notion-Version": "2022-06-28", | |||
| }, | |||
| json=query_dict | |||
| ) | |||
| data = res.json() | |||
| for result in data["results"]: | |||
| page_id = result["id"] | |||
| page_ids.append(page_id) | |||
| # get table headers text | |||
| table_header_cell_texts = [] | |||
| tabel_header_cells = data["results"][0]['table_row']['cells'] | |||
| for tabel_header_cell in tabel_header_cells: | |||
| if tabel_header_cell: | |||
| for table_header_cell_text in tabel_header_cell: | |||
| text = table_header_cell_text["text"]["content"] | |||
| table_header_cell_texts.append(text) | |||
| # get table columns text and format | |||
| results = data["results"] | |||
| for i in range(len(results) - 1): | |||
| column_texts = [] | |||
| tabel_column_cells = data["results"][i + 1]['table_row']['cells'] | |||
| for j in range(len(tabel_column_cells)): | |||
| if tabel_column_cells[j]: | |||
| for table_column_cell_text in tabel_column_cells[j]: | |||
| column_text = table_column_cell_text["text"]["content"] | |||
| column_texts.append(f'{table_header_cell_texts[j]}:{column_text}') | |||
| cur_result_text = "\n".join(column_texts) | |||
| result_lines_arr.append(cur_result_text) | |||
| if data["next_cursor"] is None: | |||
| done = True | |||
| break | |||
| else: | |||
| next_cursor = data["next_cursor"] | |||
| return page_ids | |||
| cur_block_id = data["next_cursor"] | |||
| def load_data( | |||
| self, page_ids: List[str] = [], database_id: Optional[str] = None | |||
| ) -> List[Document]: | |||
| """Load data from the input directory. | |||
| result_lines = "\n".join(result_lines_arr) | |||
| return result_lines | |||
| Args: | |||
| page_ids (List[str]): List of page ids to load. | |||
| def update_last_edited_time(self, document_model: DocumentModel): | |||
| if not document_model: | |||
| return | |||
| Returns: | |||
| List[Document]: List of documents. | |||
| last_edited_time = self.get_notion_last_edited_time() | |||
| data_source_info = document_model.data_source_info_dict | |||
| data_source_info['last_edited_time'] = last_edited_time | |||
| update_params = { | |||
| DocumentModel.data_source_info: json.dumps(data_source_info) | |||
| } | |||
| """ | |||
| if not page_ids and not database_id: | |||
| raise ValueError("Must specify either `page_ids` or `database_id`.") | |||
| docs = [] | |||
| if database_id is not None: | |||
| # get all the pages in the database | |||
| page_ids = self.query_database(database_id) | |||
| for page_id in page_ids: | |||
| page_text = self.read_page(page_id) | |||
| docs.append(Document(page_text)) | |||
| else: | |||
| for page_id in page_ids: | |||
| page_text = self.read_page(page_id) | |||
| docs.append(Document(page_text)) | |||
| DocumentModel.query.filter_by(id=document_model.id).update(update_params) | |||
| db.session.commit() | |||
| return docs | |||
| def load_data_as_documents( | |||
| self, page_ids: List[str] = [], database_id: Optional[str] = None | |||
| ) -> List[Document]: | |||
| if not page_ids and not database_id: | |||
| raise ValueError("Must specify either `page_ids` or `database_id`.") | |||
| docs = [] | |||
| if database_id is not None: | |||
| # get all the pages in the database | |||
| page_text = self.query_database_data(database_id) | |||
| docs.append(Document(page_text)) | |||
| def get_notion_last_edited_time(self) -> str: | |||
| obj_id = self._notion_obj_id | |||
| page_type = self._notion_page_type | |||
| if page_type == 'database': | |||
| retrieve_page_url = RETRIEVE_DATABASE_URL_TMPL.format(database_id=obj_id) | |||
| else: | |||
| for page_id in page_ids: | |||
| page_text_list = self.read_page_as_documents(page_id) | |||
| for page_text in page_text_list: | |||
| docs.append(Document(page_text)) | |||
| retrieve_page_url = RETRIEVE_PAGE_URL_TMPL.format(page_id=obj_id) | |||
| return docs | |||
| def get_page_last_edited_time(self, page_id: str) -> str: | |||
| retrieve_page_url = RETRIEVE_PAGE_URL_TMPL.format(page_id=page_id) | |||
| query_dict: Dict[str, Any] = {} | |||
| res = requests.request( | |||
| "GET", retrieve_page_url, headers=self.headers, json=query_dict | |||
| "GET", | |||
| retrieve_page_url, | |||
| headers={ | |||
| "Authorization": "Bearer " + self._notion_access_token, | |||
| "Content-Type": "application/json", | |||
| "Notion-Version": "2022-06-28", | |||
| }, | |||
| json=query_dict | |||
| ) | |||
| data = res.json() | |||
| return data["last_edited_time"] | |||
| def get_database_last_edited_time(self, database_id: str) -> str: | |||
| retrieve_page_url = RETRIEVE_DATABASE_URL_TMPL.format(database_id=database_id) | |||
| query_dict: Dict[str, Any] = {} | |||
| res = requests.request( | |||
| "GET", retrieve_page_url, headers=self.headers, json=query_dict | |||
| ) | |||
| data = res.json() | |||
| return data["last_edited_time"] | |||
| @classmethod | |||
| def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str: | |||
| data_source_binding = DataSourceBinding.query.filter( | |||
| db.and_( | |||
| DataSourceBinding.tenant_id == tenant_id, | |||
| DataSourceBinding.provider == 'notion', | |||
| DataSourceBinding.disabled == False, | |||
| DataSourceBinding.source_info['workspace_id'] == f'"{notion_workspace_id}"' | |||
| ) | |||
| ).first() | |||
| if not data_source_binding: | |||
| raise Exception(f'No notion data source binding found for tenant {tenant_id} ' | |||
| f'and notion workspace {notion_workspace_id}') | |||
| if __name__ == "__main__": | |||
| reader = NotionPageReader() | |||
| logger.info(reader.search("What I")) | |||
| return data_source_binding.access_token | |||
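To show the reworked interface in isolation, a sketch of constructing NotionLoader directly rather than via from_document; every value below is a placeholder, and with no document_model the last_edited_time update is simply skipped.

```python
# Hypothetical import path; the token and IDs are placeholders, not real credentials.
from core.data_loader.loader.notion import NotionLoader

loader = NotionLoader(
    notion_access_token="secret_xxx",
    notion_workspace_id="workspace-id",
    notion_obj_id="page-or-database-id",
    notion_page_type="page",   # or "database" to dump the database rows as JSON
)

docs = loader.load()  # for a page: one Document per top-level block group
```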
| @@ -0,0 +1,55 @@ | |||
| import logging | |||
| from typing import List, Optional | |||
| from langchain.document_loaders import PyPDFium2Loader | |||
| from langchain.document_loaders.base import BaseLoader | |||
| from langchain.schema import Document | |||
| from extensions.ext_storage import storage | |||
| from models.model import UploadFile | |||
| logger = logging.getLogger(__name__) | |||
| class PdfLoader(BaseLoader): | |||
| """Load pdf files. | |||
| Args: | |||
| file_path: Path to the file to load. | |||
| """ | |||
| def __init__( | |||
| self, | |||
| file_path: str, | |||
| upload_file: Optional[UploadFile] = None | |||
| ): | |||
| """Initialize with file path.""" | |||
| self._file_path = file_path | |||
| self._upload_file = upload_file | |||
| def load(self) -> List[Document]: | |||
| plaintext_file_key = '' | |||
| plaintext_file_exists = False | |||
| if self._upload_file: | |||
| if self._upload_file.hash: | |||
| plaintext_file_key = 'upload_files/' + self._upload_file.tenant_id + '/' \ | |||
| + self._upload_file.hash + '.0625.plaintext' | |||
| try: | |||
| text = storage.load(plaintext_file_key).decode('utf-8') | |||
| plaintext_file_exists = True | |||
| return [Document(page_content=text)] | |||
| except FileNotFoundError: | |||
| pass | |||
| documents = PyPDFium2Loader(file_path=self._file_path).load() | |||
| text_list = [] | |||
| for document in documents: | |||
| text_list.append(document.page_content) | |||
| text = "\n\n".join(text_list) | |||
| # save plaintext file for caching | |||
| if not plaintext_file_exists and plaintext_file_key: | |||
| storage.save(plaintext_file_key, text.encode('utf-8')) | |||
| return documents | |||
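A small sketch of the new PdfLoader; the file path is a placeholder. When an UploadFile with a hash is passed, the extracted text is cached in object storage under upload_files/&lt;tenant_id&gt;/&lt;hash&gt;.0625.plaintext and returned as a single Document on later loads; without one, the per-page Documents from PyPDFium2Loader are returned directly.

```python
# Hypothetical file path; no UploadFile here, so no plaintext cache is written.
from core.data_loader.loader.pdf import PdfLoader  # assumed module path

loader = PdfLoader(file_path="/tmp/paper.pdf", upload_file=None)
docs = loader.load()
print(len(docs), "page(s) extracted")
```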
| @@ -1,10 +1,6 @@ | |||
| from typing import Any, Dict, Optional, Sequence | |||
| import tiktoken | |||
| from llama_index.data_structs import Node | |||
| from llama_index.docstore.types import BaseDocumentStore | |||
| from llama_index.docstore.utils import json_to_doc | |||
| from llama_index.schema import BaseDocument | |||
| from langchain.schema import Document | |||
| from sqlalchemy import func | |||
| from core.llm.token_calculator import TokenCalculator | |||
| @@ -12,7 +8,7 @@ from extensions.ext_database import db | |||
| from models.dataset import Dataset, DocumentSegment | |||
| class DatesetDocumentStore(BaseDocumentStore): | |||
| class DatesetDocumentStore: | |||
| def __init__( | |||
| self, | |||
| dataset: Dataset, | |||
| @@ -48,7 +44,7 @@ class DatesetDocumentStore(BaseDocumentStore): | |||
| return self._embedding_model_name | |||
| @property | |||
| def docs(self) -> Dict[str, BaseDocument]: | |||
| def docs(self) -> Dict[str, Document]: | |||
| document_segments = db.session.query(DocumentSegment).filter( | |||
| DocumentSegment.dataset_id == self._dataset.id | |||
| ).all() | |||
| @@ -56,13 +52,20 @@ class DatesetDocumentStore(BaseDocumentStore): | |||
| output = {} | |||
| for document_segment in document_segments: | |||
| doc_id = document_segment.index_node_id | |||
| result = self.segment_to_dict(document_segment) | |||
| output[doc_id] = json_to_doc(result) | |||
| output[doc_id] = Document( | |||
| page_content=document_segment.content, | |||
| metadata={ | |||
| "doc_id": document_segment.index_node_id, | |||
| "doc_hash": document_segment.index_node_hash, | |||
| "document_id": document_segment.document_id, | |||
| "dataset_id": document_segment.dataset_id, | |||
| } | |||
| ) | |||
| return output | |||
| def add_documents( | |||
| self, docs: Sequence[BaseDocument], allow_update: bool = True | |||
| self, docs: Sequence[Document], allow_update: bool = True | |||
| ) -> None: | |||
| max_position = db.session.query(func.max(DocumentSegment.position)).filter( | |||
| DocumentSegment.document == self._document_id | |||
| @@ -72,23 +75,20 @@ class DatesetDocumentStore(BaseDocumentStore): | |||
| max_position = 0 | |||
| for doc in docs: | |||
| if doc.is_doc_id_none: | |||
| raise ValueError("doc_id not set") | |||
| if not isinstance(doc, Document): | |||
| raise ValueError("doc must be a Document") | |||
| if not isinstance(doc, Node): | |||
| raise ValueError("doc must be a Node") | |||
| segment_document = self.get_document(doc_id=doc.get_doc_id(), raise_error=False) | |||
| segment_document = self.get_document(doc_id=doc.metadata['doc_id'], raise_error=False) | |||
| # NOTE: doc could already exist in the store, but we overwrite it | |||
| if not allow_update and segment_document: | |||
| raise ValueError( | |||
| f"doc_id {doc.get_doc_id()} already exists. " | |||
| f"doc_id {doc.metadata['doc_id']} already exists. " | |||
| "Set allow_update to True to overwrite." | |||
| ) | |||
| # calc embedding use tokens | |||
| tokens = TokenCalculator.get_num_tokens(self._embedding_model_name, doc.get_text()) | |||
| tokens = TokenCalculator.get_num_tokens(self._embedding_model_name, doc.page_content) | |||
| if not segment_document: | |||
| max_position += 1 | |||
| @@ -97,19 +97,19 @@ class DatesetDocumentStore(BaseDocumentStore): | |||
| tenant_id=self._dataset.tenant_id, | |||
| dataset_id=self._dataset.id, | |||
| document_id=self._document_id, | |||
| index_node_id=doc.get_doc_id(), | |||
| index_node_hash=doc.get_doc_hash(), | |||
| index_node_id=doc.metadata['doc_id'], | |||
| index_node_hash=doc.metadata['doc_hash'], | |||
| position=max_position, | |||
| content=doc.get_text(), | |||
| word_count=len(doc.get_text()), | |||
| content=doc.page_content, | |||
| word_count=len(doc.page_content), | |||
| tokens=tokens, | |||
| created_by=self._user_id, | |||
| ) | |||
| db.session.add(segment_document) | |||
| else: | |||
| segment_document.content = doc.get_text() | |||
| segment_document.index_node_hash = doc.get_doc_hash() | |||
| segment_document.word_count = len(doc.get_text()) | |||
| segment_document.content = doc.page_content | |||
| segment_document.index_node_hash = doc.metadata['doc_hash'] | |||
| segment_document.word_count = len(doc.page_content) | |||
| segment_document.tokens = tokens | |||
| db.session.commit() | |||
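With the llama-index types gone, add_documents now reads everything it needs from langchain Document objects. A hedged sketch of the expected shape (the UUID and hash values are illustrative):

```python
import uuid

from langchain.schema import Document

# doc_id becomes index_node_id and doc_hash becomes index_node_hash;
# page_content is stored as the segment content and drives the token count.
doc = Document(
    page_content="segment text",
    metadata={
        "doc_id": str(uuid.uuid4()),
        "doc_hash": "<sha-256 of the content>",  # placeholder value
    },
)
# store.add_documents([doc])  # 'store' is an existing DatesetDocumentStore
```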
| @@ -121,7 +121,7 @@ class DatesetDocumentStore(BaseDocumentStore): | |||
| def get_document( | |||
| self, doc_id: str, raise_error: bool = True | |||
| ) -> Optional[BaseDocument]: | |||
| ) -> Optional[Document]: | |||
| document_segment = self.get_document_segment(doc_id) | |||
| if document_segment is None: | |||
| @@ -130,8 +130,15 @@ class DatesetDocumentStore(BaseDocumentStore): | |||
| else: | |||
| return None | |||
| result = self.segment_to_dict(document_segment) | |||
| return json_to_doc(result) | |||
| return Document( | |||
| page_content=document_segment.content, | |||
| metadata={ | |||
| "doc_id": document_segment.index_node_id, | |||
| "doc_hash": document_segment.index_node_hash, | |||
| "document_id": document_segment.document_id, | |||
| "dataset_id": document_segment.dataset_id, | |||
| } | |||
| ) | |||
| def delete_document(self, doc_id: str, raise_error: bool = True) -> None: | |||
| document_segment = self.get_document_segment(doc_id) | |||
| @@ -164,15 +171,6 @@ class DatesetDocumentStore(BaseDocumentStore): | |||
| return document_segment.index_node_hash | |||
| def update_docstore(self, other: "BaseDocumentStore") -> None: | |||
| """Update docstore. | |||
| Args: | |||
| other (BaseDocumentStore): docstore to update from | |||
| """ | |||
| self.add_documents(list(other.docs.values())) | |||
| def get_document_segment(self, doc_id: str) -> DocumentSegment: | |||
| document_segment = db.session.query(DocumentSegment).filter( | |||
| DocumentSegment.dataset_id == self._dataset.id, | |||
| @@ -180,11 +178,3 @@ class DatesetDocumentStore(BaseDocumentStore): | |||
| ).first() | |||
| return document_segment | |||
| def segment_to_dict(self, segment: DocumentSegment) -> Dict[str, Any]: | |||
| return { | |||
| "doc_id": segment.index_node_id, | |||
| "doc_hash": segment.index_node_hash, | |||
| "text": segment.content, | |||
| "__type__": Node.get_type() | |||
| } | |||
| @@ -1,51 +0,0 @@ | |||
| from typing import Any, Dict, Optional, Sequence | |||
| from llama_index.docstore.types import BaseDocumentStore | |||
| from llama_index.schema import BaseDocument | |||
| class EmptyDocumentStore(BaseDocumentStore): | |||
| @classmethod | |||
| def from_dict(cls, config_dict: Dict[str, Any]) -> "EmptyDocumentStore": | |||
| return cls() | |||
| def to_dict(self) -> Dict[str, Any]: | |||
| """Serialize to dict.""" | |||
| return {} | |||
| @property | |||
| def docs(self) -> Dict[str, BaseDocument]: | |||
| return {} | |||
| def add_documents( | |||
| self, docs: Sequence[BaseDocument], allow_update: bool = True | |||
| ) -> None: | |||
| pass | |||
| def document_exists(self, doc_id: str) -> bool: | |||
| """Check if document exists.""" | |||
| return False | |||
| def get_document( | |||
| self, doc_id: str, raise_error: bool = True | |||
| ) -> Optional[BaseDocument]: | |||
| return None | |||
| def delete_document(self, doc_id: str, raise_error: bool = True) -> None: | |||
| pass | |||
| def set_document_hash(self, doc_id: str, doc_hash: str) -> None: | |||
| """Set the hash for a given doc_id.""" | |||
| pass | |||
| def get_document_hash(self, doc_id: str) -> Optional[str]: | |||
| """Get the stored hash for a document, if it exists.""" | |||
| return None | |||
| def update_docstore(self, other: "BaseDocumentStore") -> None: | |||
| """Update docstore. | |||
| Args: | |||
| other (BaseDocumentStore): docstore to update from | |||
| """ | |||
| self.add_documents(list(other.docs.values())) | |||
| @@ -0,0 +1,72 @@ | |||
| import logging | |||
| from typing import List | |||
| from langchain.embeddings.base import Embeddings | |||
| from sqlalchemy.exc import IntegrityError | |||
| from extensions.ext_database import db | |||
| from libs import helper | |||
| from models.dataset import Embedding | |||
| class CacheEmbedding(Embeddings): | |||
| def __init__(self, embeddings: Embeddings): | |||
| self._embeddings = embeddings | |||
| def embed_documents(self, texts: List[str]) -> List[List[float]]: | |||
| """Embed search docs.""" | |||
| # use doc embedding cache or store if not exists | |||
| text_embeddings = [] | |||
| embedding_queue_texts = [] | |||
| for text in texts: | |||
| hash = helper.generate_text_hash(text) | |||
| embedding = db.session.query(Embedding).filter_by(hash=hash).first() | |||
| if embedding: | |||
| text_embeddings.append(embedding.get_embedding()) | |||
| else: | |||
| embedding_queue_texts.append(text) | |||
| embedding_results = self._embeddings.embed_documents(embedding_queue_texts) | |||
| # enumerate keeps each text aligned with its embedding even when a row is skipped | |||
| for i, text in enumerate(embedding_queue_texts): | |||
| hash = helper.generate_text_hash(text) | |||
| try: | |||
| embedding = Embedding(hash=hash) | |||
| embedding.set_embedding(embedding_results[i]) | |||
| db.session.add(embedding) | |||
| db.session.commit() | |||
| except IntegrityError: | |||
| db.session.rollback() | |||
| continue | |||
| except Exception: | |||
| logging.exception('Failed to add embedding to db') | |||
| continue | |||
| text_embeddings.extend(embedding_results) | |||
| return text_embeddings | |||
| def embed_query(self, text: str) -> List[float]: | |||
| """Embed query text.""" | |||
| # use doc embedding cache or store if not exists | |||
| hash = helper.generate_text_hash(text) | |||
| embedding = db.session.query(Embedding).filter_by(hash=hash).first() | |||
| if embedding: | |||
| return embedding.get_embedding() | |||
| embedding_results = self._embeddings.embed_query(text) | |||
| try: | |||
| embedding = Embedding(hash=hash) | |||
| embedding.set_embedding(embedding_results) | |||
| db.session.add(embedding) | |||
| db.session.commit() | |||
| except IntegrityError: | |||
| db.session.rollback() | |||
| except Exception: | |||
| logging.exception('Failed to add embedding to db') | |||
| return embedding_results | |||
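A usage sketch of the new CacheEmbedding wrapper, assuming a Flask application context is active so the Embedding table is reachable (the API key is a placeholder):

```python
from langchain.embeddings import OpenAIEmbeddings

from core.embedding.cached_embedding import CacheEmbedding  # assumed module path

# The first call computes and persists the vectors; repeated texts are then
# served from the Embedding table instead of hitting the OpenAI API again.
embeddings = CacheEmbedding(OpenAIEmbeddings(openai_api_key="sk-placeholder"))
doc_vectors = embeddings.embed_documents(["hello world", "hello world"])
query_vector = embeddings.embed_query("hello world")
```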
| @@ -1,214 +0,0 @@ | |||
| from typing import Optional, Any, List | |||
| import openai | |||
| from llama_index.embeddings.base import BaseEmbedding | |||
| from llama_index.embeddings.openai import OpenAIEmbeddingMode, OpenAIEmbeddingModelType, _QUERY_MODE_MODEL_DICT, \ | |||
| _TEXT_MODE_MODEL_DICT | |||
| from tenacity import wait_random_exponential, retry, stop_after_attempt | |||
| from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async | |||
| @retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6)) | |||
| def get_embedding( | |||
| text: str, | |||
| engine: Optional[str] = None, | |||
| api_key: Optional[str] = None, | |||
| **kwargs | |||
| ) -> List[float]: | |||
| """Get embedding. | |||
| NOTE: Copied from OpenAI's embedding utils: | |||
| https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py | |||
| Copied here to avoid importing unnecessary dependencies | |||
| like matplotlib, plotly, scipy, sklearn. | |||
| """ | |||
| text = text.replace("\n", " ") | |||
| return openai.Embedding.create(input=[text], engine=engine, api_key=api_key, **kwargs)["data"][0]["embedding"] | |||
| @retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6)) | |||
| async def aget_embedding(text: str, engine: Optional[str] = None, api_key: Optional[str] = None, **kwargs) -> List[ | |||
| float]: | |||
| """Asynchronously get embedding. | |||
| NOTE: Copied from OpenAI's embedding utils: | |||
| https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py | |||
| Copied here to avoid importing unnecessary dependencies | |||
| like matplotlib, plotly, scipy, sklearn. | |||
| """ | |||
| # replace newlines, which can negatively affect performance. | |||
| text = text.replace("\n", " ") | |||
| return (await openai.Embedding.acreate(input=[text], engine=engine, api_key=api_key, **kwargs))["data"][0][ | |||
| "embedding" | |||
| ] | |||
| @retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6)) | |||
| def get_embeddings( | |||
| list_of_text: List[str], | |||
| engine: Optional[str] = None, | |||
| api_key: Optional[str] = None, | |||
| **kwargs | |||
| ) -> List[List[float]]: | |||
| """Get embeddings. | |||
| NOTE: Copied from OpenAI's embedding utils: | |||
| https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py | |||
| Copied here to avoid importing unnecessary dependencies | |||
| like matplotlib, plotly, scipy, sklearn. | |||
| """ | |||
| assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048." | |||
| # replace newlines, which can negatively affect performance. | |||
| list_of_text = [text.replace("\n", " ") for text in list_of_text] | |||
| data = openai.Embedding.create(input=list_of_text, engine=engine, api_key=api_key, **kwargs).data | |||
| data = sorted(data, key=lambda x: x["index"]) # maintain the same order as input. | |||
| return [d["embedding"] for d in data] | |||
| @retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6)) | |||
| async def aget_embeddings( | |||
| list_of_text: List[str], engine: Optional[str] = None, api_key: Optional[str] = None, **kwargs | |||
| ) -> List[List[float]]: | |||
| """Asynchronously get embeddings. | |||
| NOTE: Copied from OpenAI's embedding utils: | |||
| https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py | |||
| Copied here to avoid importing unnecessary dependencies | |||
| like matplotlib, plotly, scipy, sklearn. | |||
| """ | |||
| assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048." | |||
| # replace newlines, which can negatively affect performance. | |||
| list_of_text = [text.replace("\n", " ") for text in list_of_text] | |||
| data = (await openai.Embedding.acreate(input=list_of_text, engine=engine, api_key=api_key, **kwargs)).data | |||
| data = sorted(data, key=lambda x: x["index"]) # maintain the same order as input. | |||
| return [d["embedding"] for d in data] | |||
| class OpenAIEmbedding(BaseEmbedding): | |||
| def __init__( | |||
| self, | |||
| mode: str = OpenAIEmbeddingMode.TEXT_SEARCH_MODE, | |||
| model: str = OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002, | |||
| deployment_name: Optional[str] = None, | |||
| openai_api_key: Optional[str] = None, | |||
| **kwargs: Any, | |||
| ) -> None: | |||
| """Init params.""" | |||
| new_kwargs = {} | |||
| if 'embed_batch_size' in kwargs: | |||
| new_kwargs['embed_batch_size'] = kwargs['embed_batch_size'] | |||
| if 'tokenizer' in kwargs: | |||
| new_kwargs['tokenizer'] = kwargs['tokenizer'] | |||
| super().__init__(**new_kwargs) | |||
| self.mode = OpenAIEmbeddingMode(mode) | |||
| self.model = OpenAIEmbeddingModelType(model) | |||
| self.deployment_name = deployment_name | |||
| self.openai_api_key = openai_api_key | |||
| self.openai_api_type = kwargs.get('openai_api_type') | |||
| self.openai_api_version = kwargs.get('openai_api_version') | |||
| self.openai_api_base = kwargs.get('openai_api_base') | |||
| @handle_llm_exceptions | |||
| def _get_query_embedding(self, query: str) -> List[float]: | |||
| """Get query embedding.""" | |||
| if self.deployment_name is not None: | |||
| engine = self.deployment_name | |||
| else: | |||
| key = (self.mode, self.model) | |||
| if key not in _QUERY_MODE_MODEL_DICT: | |||
| raise ValueError(f"Invalid mode, model combination: {key}") | |||
| engine = _QUERY_MODE_MODEL_DICT[key] | |||
| return get_embedding(query, engine=engine, api_key=self.openai_api_key, | |||
| api_type=self.openai_api_type, api_version=self.openai_api_version, | |||
| api_base=self.openai_api_base) | |||
| def _get_text_embedding(self, text: str) -> List[float]: | |||
| """Get text embedding.""" | |||
| if self.deployment_name is not None: | |||
| engine = self.deployment_name | |||
| else: | |||
| key = (self.mode, self.model) | |||
| if key not in _TEXT_MODE_MODEL_DICT: | |||
| raise ValueError(f"Invalid mode, model combination: {key}") | |||
| engine = _TEXT_MODE_MODEL_DICT[key] | |||
| return get_embedding(text, engine=engine, api_key=self.openai_api_key, | |||
| api_type=self.openai_api_type, api_version=self.openai_api_version, | |||
| api_base=self.openai_api_base) | |||
| async def _aget_text_embedding(self, text: str) -> List[float]: | |||
| """Asynchronously get text embedding.""" | |||
| if self.deployment_name is not None: | |||
| engine = self.deployment_name | |||
| else: | |||
| key = (self.mode, self.model) | |||
| if key not in _TEXT_MODE_MODEL_DICT: | |||
| raise ValueError(f"Invalid mode, model combination: {key}") | |||
| engine = _TEXT_MODE_MODEL_DICT[key] | |||
| return await aget_embedding(text, engine=engine, api_key=self.openai_api_key, | |||
| api_type=self.openai_api_type, api_version=self.openai_api_version, | |||
| api_base=self.openai_api_base) | |||
| def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]: | |||
| """Get text embeddings. | |||
| By default, this is a wrapper around _get_text_embedding. | |||
| Can be overriden for batch queries. | |||
| """ | |||
| if self.openai_api_type and self.openai_api_type == 'azure': | |||
| embeddings = [] | |||
| for text in texts: | |||
| embeddings.append(self._get_text_embedding(text)) | |||
| return embeddings | |||
| if self.deployment_name is not None: | |||
| engine = self.deployment_name | |||
| else: | |||
| key = (self.mode, self.model) | |||
| if key not in _TEXT_MODE_MODEL_DICT: | |||
| raise ValueError(f"Invalid mode, model combination: {key}") | |||
| engine = _TEXT_MODE_MODEL_DICT[key] | |||
| embeddings = get_embeddings(texts, engine=engine, api_key=self.openai_api_key, | |||
| api_type=self.openai_api_type, api_version=self.openai_api_version, | |||
| api_base=self.openai_api_base) | |||
| return embeddings | |||
| async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]: | |||
| """Asynchronously get text embeddings.""" | |||
| if self.openai_api_type and self.openai_api_type == 'azure': | |||
| embeddings = [] | |||
| for text in texts: | |||
| embeddings.append(await self._aget_text_embedding(text)) | |||
| return embeddings | |||
| if self.deployment_name is not None: | |||
| engine = self.deployment_name | |||
| else: | |||
| key = (self.mode, self.model) | |||
| if key not in _TEXT_MODE_MODEL_DICT: | |||
| raise ValueError(f"Invalid mode, model combination: {key}") | |||
| engine = _TEXT_MODE_MODEL_DICT[key] | |||
| embeddings = await aget_embeddings(texts, engine=engine, api_key=self.openai_api_key, | |||
| api_type=self.openai_api_type, api_version=self.openai_api_version, | |||
| api_base=self.openai_api_base) | |||
| return embeddings | |||
| @@ -0,0 +1,59 @@ | |||
| from __future__ import annotations | |||
| from abc import abstractmethod, ABC | |||
| from typing import List, Any | |||
| from langchain.schema import Document, BaseRetriever | |||
| from models.dataset import Dataset | |||
| class BaseIndex(ABC): | |||
| def __init__(self, dataset: Dataset): | |||
| self.dataset = dataset | |||
| @abstractmethod | |||
| def create(self, texts: list[Document], **kwargs) -> BaseIndex: | |||
| raise NotImplementedError | |||
| @abstractmethod | |||
| def add_texts(self, texts: list[Document], **kwargs): | |||
| raise NotImplementedError | |||
| @abstractmethod | |||
| def text_exists(self, id: str) -> bool: | |||
| raise NotImplementedError | |||
| @abstractmethod | |||
| def delete_by_ids(self, ids: list[str]) -> None: | |||
| raise NotImplementedError | |||
| @abstractmethod | |||
| def delete_by_document_id(self, document_id: str): | |||
| raise NotImplementedError | |||
| @abstractmethod | |||
| def get_retriever(self, **kwargs: Any) -> BaseRetriever: | |||
| raise NotImplementedError | |||
| @abstractmethod | |||
| def search( | |||
| self, query: str, | |||
| **kwargs: Any | |||
| ) -> List[Document]: | |||
| raise NotImplementedError | |||
| def delete(self) -> None: | |||
| raise NotImplementedError | |||
| def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]: | |||
| # iterate over a copy: removing from the list being iterated would skip the next element | |||
| for text in texts[:]: | |||
| doc_id = text.metadata['doc_id'] | |||
| if self.text_exists(doc_id): | |||
| texts.remove(text) | |||
| return texts | |||
| def _get_uuids(self, texts: list[Document]) -> list[str]: | |||
| return [text.metadata['doc_id'] for text in texts] | |||
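Purely for orientation, a do-nothing stub showing the surface a concrete index has to implement; the real subclasses in this change are VectorIndex and KeywordTableIndex, and the import path is an assumption.

```python
from __future__ import annotations

from typing import Any, List

from langchain.schema import BaseRetriever, Document

from core.index.base import BaseIndex  # assumed module path


class NoopIndex(BaseIndex):
    """Illustrative stub only: every abstract method is satisfied with a no-op."""

    def create(self, texts: list[Document], **kwargs) -> BaseIndex:
        return self

    def add_texts(self, texts: list[Document], **kwargs):
        pass

    def text_exists(self, id: str) -> bool:
        return False

    def delete_by_ids(self, ids: list[str]) -> None:
        pass

    def delete_by_document_id(self, document_id: str):
        pass

    def get_retriever(self, **kwargs: Any) -> BaseRetriever:
        raise NotImplementedError("stub")

    def search(self, query: str, **kwargs: Any) -> List[Document]:
        return []
```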
| @@ -0,0 +1,41 @@ | |||
| from flask import current_app | |||
| from langchain.embeddings import OpenAIEmbeddings | |||
| from core.embedding.cached_embedding import CacheEmbedding | |||
| from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig | |||
| from core.index.vector_index.vector_index import VectorIndex | |||
| from core.llm.llm_builder import LLMBuilder | |||
| from models.dataset import Dataset | |||
| class IndexBuilder: | |||
| @classmethod | |||
| def get_index(cls, dataset: Dataset, indexing_technique: str, ignore_high_quality_check: bool = False): | |||
| if indexing_technique == "high_quality": | |||
| if not ignore_high_quality_check and dataset.indexing_technique != 'high_quality': | |||
| return None | |||
| model_credentials = LLMBuilder.get_model_credentials( | |||
| tenant_id=dataset.tenant_id, | |||
| model_provider=LLMBuilder.get_default_provider(dataset.tenant_id), | |||
| model_name='text-embedding-ada-002' | |||
| ) | |||
| embeddings = CacheEmbedding(OpenAIEmbeddings( | |||
| **model_credentials | |||
| )) | |||
| return VectorIndex( | |||
| dataset=dataset, | |||
| config=current_app.config, | |||
| embeddings=embeddings | |||
| ) | |||
| elif indexing_technique == "economy": | |||
| return KeywordTableIndex( | |||
| dataset=dataset, | |||
| config=KeywordTableConfig( | |||
| max_keywords_per_chunk=10 | |||
| ) | |||
| ) | |||
| else: | |||
| raise ValueError('Unknown indexing technique') | |||
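Finally, a hedged sketch of calling the rewritten IndexBuilder from inside an app context; the module path is an assumption and the query text is illustrative.

```python
from core.index.index import IndexBuilder  # assumed module path
from extensions.ext_database import db
from models.dataset import Dataset

# Pick any dataset and build the matching index.
dataset = db.session.query(Dataset).first()

# get_index returns a VectorIndex for 'high_quality', a KeywordTableIndex for
# 'economy', and None when a high-quality index is requested for a dataset
# whose indexing_technique is not 'high_quality'.
index = IndexBuilder.get_index(dataset, dataset.indexing_technique)
if index is not None:
    results = index.search("example query")
```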
| @@ -1,60 +0,0 @@ | |||
| from langchain.callbacks import CallbackManager | |||
| from llama_index import ServiceContext, PromptHelper, LLMPredictor | |||
| from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler | |||
| from core.embedding.openai_embedding import OpenAIEmbedding | |||
| from core.llm.llm_builder import LLMBuilder | |||
| class IndexBuilder: | |||
| @classmethod | |||
| def get_default_service_context(cls, tenant_id: str) -> ServiceContext: | |||
| # set number of output tokens | |||
| num_output = 512 | |||
| # only for verbose | |||
| callback_manager = CallbackManager([DifyStdOutCallbackHandler()]) | |||
| llm = LLMBuilder.to_llm( | |||
| tenant_id=tenant_id, | |||
| model_name='text-davinci-003', | |||
| temperature=0, | |||
| max_tokens=num_output, | |||
| callback_manager=callback_manager, | |||
| ) | |||
| llm_predictor = LLMPredictor(llm=llm) | |||
| # These parameters here will affect the logic of segmenting the final synthesized response. | |||
| # The number of refinement iterations in the synthesis process depends | |||
| # on whether the length of the segmented output exceeds the max_input_size. | |||
| prompt_helper = PromptHelper( | |||
| max_input_size=3500, | |||
| num_output=num_output, | |||
| max_chunk_overlap=20 | |||
| ) | |||
| provider = LLMBuilder.get_default_provider(tenant_id) | |||
| model_credentials = LLMBuilder.get_model_credentials( | |||
| tenant_id=tenant_id, | |||
| model_provider=provider, | |||
| model_name='text-embedding-ada-002' | |||
| ) | |||
| return ServiceContext.from_defaults( | |||
| llm_predictor=llm_predictor, | |||
| prompt_helper=prompt_helper, | |||
| embed_model=OpenAIEmbedding(**model_credentials), | |||
| ) | |||
| @classmethod | |||
| def get_fake_llm_service_context(cls, tenant_id: str) -> ServiceContext: | |||
| llm = LLMBuilder.to_llm( | |||
| tenant_id=tenant_id, | |||
| model_name='fake' | |||
| ) | |||
| return ServiceContext.from_defaults( | |||
| llm_predictor=LLMPredictor(llm=llm), | |||
| embed_model=OpenAIEmbedding() | |||
| ) | |||
| @@ -1,159 +0,0 @@ | |||
| import re | |||
| from typing import ( | |||
| Any, | |||
| Dict, | |||
| List, | |||
| Set, | |||
| Optional | |||
| ) | |||
| import jieba.analyse | |||
| from core.index.keyword_table.stopwords import STOPWORDS | |||
| from llama_index.indices.query.base import IS | |||
| from llama_index import QueryMode | |||
| from llama_index.indices.base import QueryMap | |||
| from llama_index.indices.keyword_table.base import BaseGPTKeywordTableIndex | |||
| from llama_index.indices.keyword_table.query import BaseGPTKeywordTableQuery | |||
| from llama_index.docstore import BaseDocumentStore | |||
| from llama_index.indices.postprocessor.node import ( | |||
| BaseNodePostprocessor, | |||
| ) | |||
| from llama_index.indices.response.response_builder import ResponseMode | |||
| from llama_index.indices.service_context import ServiceContext | |||
| from llama_index.optimization.optimizer import BaseTokenUsageOptimizer | |||
| from llama_index.prompts.prompts import ( | |||
| QuestionAnswerPrompt, | |||
| RefinePrompt, | |||
| SimpleInputPrompt, | |||
| ) | |||
| from core.index.query.synthesizer import EnhanceResponseSynthesizer | |||
| def jieba_extract_keywords( | |||
| text_chunk: str, | |||
| max_keywords: Optional[int] = None, | |||
| expand_with_subtokens: bool = True, | |||
| ) -> Set[str]: | |||
| """Extract keywords with JIEBA tfidf.""" | |||
| keywords = jieba.analyse.extract_tags( | |||
| sentence=text_chunk, | |||
| topK=max_keywords, | |||
| ) | |||
| if expand_with_subtokens: | |||
| return set(expand_tokens_with_subtokens(keywords)) | |||
| else: | |||
| return set(keywords) | |||
| def expand_tokens_with_subtokens(tokens: Set[str]) -> Set[str]: | |||
| """Get subtokens from a list of tokens., filtering for stopwords.""" | |||
| results = set() | |||
| for token in tokens: | |||
| results.add(token) | |||
| sub_tokens = re.findall(r"\w+", token) | |||
| if len(sub_tokens) > 1: | |||
| results.update({w for w in sub_tokens if w not in list(STOPWORDS)}) | |||
| return results | |||
| class GPTJIEBAKeywordTableIndex(BaseGPTKeywordTableIndex): | |||
| """GPT JIEBA Keyword Table Index. | |||
| This index uses a JIEBA keyword extractor to extract keywords from the text. | |||
| """ | |||
| def _extract_keywords(self, text: str) -> Set[str]: | |||
| """Extract keywords from text.""" | |||
| return jieba_extract_keywords(text, max_keywords=self.max_keywords_per_chunk) | |||
| @classmethod | |||
| def get_query_map(self) -> QueryMap: | |||
| """Get query map.""" | |||
| super_map = super().get_query_map() | |||
| super_map[QueryMode.DEFAULT] = GPTKeywordTableJIEBAQuery | |||
| return super_map | |||
| def _delete(self, doc_id: str, **delete_kwargs: Any) -> None: | |||
| """Delete a document.""" | |||
| # get set of ids that correspond to node | |||
| node_idxs_to_delete = {doc_id} | |||
| # delete node_idxs from keyword to node idxs mapping | |||
| keywords_to_delete = set() | |||
| for keyword, node_idxs in self._index_struct.table.items(): | |||
| if node_idxs_to_delete.intersection(node_idxs): | |||
| self._index_struct.table[keyword] = node_idxs.difference( | |||
| node_idxs_to_delete | |||
| ) | |||
| if not self._index_struct.table[keyword]: | |||
| keywords_to_delete.add(keyword) | |||
| for keyword in keywords_to_delete: | |||
| del self._index_struct.table[keyword] | |||
| class GPTKeywordTableJIEBAQuery(BaseGPTKeywordTableQuery): | |||
| """GPT Keyword Table Index JIEBA Query. | |||
| Extracts keywords using JIEBA keyword extractor. | |||
| Set when `mode="jieba"` in `query` method of `GPTKeywordTableIndex`. | |||
| .. code-block:: python | |||
| response = index.query("<query_str>", mode="jieba") | |||
| See BaseGPTKeywordTableQuery for arguments. | |||
| """ | |||
| @classmethod | |||
| def from_args( | |||
| cls, | |||
| index_struct: IS, | |||
| service_context: ServiceContext, | |||
| docstore: Optional[BaseDocumentStore] = None, | |||
| node_postprocessors: Optional[List[BaseNodePostprocessor]] = None, | |||
| verbose: bool = False, | |||
| # response synthesizer args | |||
| response_mode: ResponseMode = ResponseMode.DEFAULT, | |||
| text_qa_template: Optional[QuestionAnswerPrompt] = None, | |||
| refine_template: Optional[RefinePrompt] = None, | |||
| simple_template: Optional[SimpleInputPrompt] = None, | |||
| response_kwargs: Optional[Dict] = None, | |||
| use_async: bool = False, | |||
| streaming: bool = False, | |||
| optimizer: Optional[BaseTokenUsageOptimizer] = None, | |||
| # class-specific args | |||
| **kwargs: Any, | |||
| ) -> "BaseGPTIndexQuery": | |||
| response_synthesizer = EnhanceResponseSynthesizer.from_args( | |||
| service_context=service_context, | |||
| text_qa_template=text_qa_template, | |||
| refine_template=refine_template, | |||
| simple_template=simple_template, | |||
| response_mode=response_mode, | |||
| response_kwargs=response_kwargs, | |||
| use_async=use_async, | |||
| streaming=streaming, | |||
| optimizer=optimizer, | |||
| ) | |||
| return cls( | |||
| index_struct=index_struct, | |||
| service_context=service_context, | |||
| response_synthesizer=response_synthesizer, | |||
| docstore=docstore, | |||
| node_postprocessors=node_postprocessors, | |||
| verbose=verbose, | |||
| **kwargs, | |||
| ) | |||
| def _get_keywords(self, query_str: str) -> List[str]: | |||
| """Extract keywords.""" | |||
| return list( | |||
| jieba_extract_keywords(query_str, max_keywords=self.max_keywords_per_query) | |||
| ) | |||
| @@ -1,135 +0,0 @@ | |||
| import json | |||
| from typing import List, Optional | |||
| from llama_index import ServiceContext, LLMPredictor, OpenAIEmbedding | |||
| from llama_index.data_structs import KeywordTable, Node | |||
| from llama_index.indices.keyword_table.base import BaseGPTKeywordTableIndex | |||
| from llama_index.indices.registry import load_index_struct_from_dict | |||
| from core.docstore.dataset_docstore import DatesetDocumentStore | |||
| from core.docstore.empty_docstore import EmptyDocumentStore | |||
| from core.index.index_builder import IndexBuilder | |||
| from core.index.keyword_table.jieba_keyword_table import GPTJIEBAKeywordTableIndex | |||
| from core.llm.llm_builder import LLMBuilder | |||
| from extensions.ext_database import db | |||
| from models.dataset import Dataset, DatasetKeywordTable, DocumentSegment | |||
| class KeywordTableIndex: | |||
| def __init__(self, dataset: Dataset): | |||
| self._dataset = dataset | |||
| def add_nodes(self, nodes: List[Node]): | |||
| llm = LLMBuilder.to_llm( | |||
| tenant_id=self._dataset.tenant_id, | |||
| model_name='fake' | |||
| ) | |||
| service_context = ServiceContext.from_defaults( | |||
| llm_predictor=LLMPredictor(llm=llm), | |||
| embed_model=OpenAIEmbedding() | |||
| ) | |||
| dataset_keyword_table = self.get_keyword_table() | |||
| if not dataset_keyword_table or not dataset_keyword_table.keyword_table_dict: | |||
| index_struct = KeywordTable() | |||
| else: | |||
| index_struct_dict = dataset_keyword_table.keyword_table_dict | |||
| index_struct: KeywordTable = load_index_struct_from_dict(index_struct_dict) | |||
| # create index | |||
| index = GPTJIEBAKeywordTableIndex( | |||
| index_struct=index_struct, | |||
| docstore=EmptyDocumentStore(), | |||
| service_context=service_context | |||
| ) | |||
| for node in nodes: | |||
| keywords = index._extract_keywords(node.get_text()) | |||
| self.update_segment_keywords(node.doc_id, list(keywords)) | |||
| index._index_struct.add_node(list(keywords), node) | |||
| index_struct_dict = index.index_struct.to_dict() | |||
| if not dataset_keyword_table: | |||
| dataset_keyword_table = DatasetKeywordTable( | |||
| dataset_id=self._dataset.id, | |||
| keyword_table=json.dumps(index_struct_dict) | |||
| ) | |||
| db.session.add(dataset_keyword_table) | |||
| else: | |||
| dataset_keyword_table.keyword_table = json.dumps(index_struct_dict) | |||
| db.session.commit() | |||
| def del_nodes(self, node_ids: List[str]): | |||
| llm = LLMBuilder.to_llm( | |||
| tenant_id=self._dataset.tenant_id, | |||
| model_name='fake' | |||
| ) | |||
| service_context = ServiceContext.from_defaults( | |||
| llm_predictor=LLMPredictor(llm=llm), | |||
| embed_model=OpenAIEmbedding() | |||
| ) | |||
| dataset_keyword_table = self.get_keyword_table() | |||
| if not dataset_keyword_table or not dataset_keyword_table.keyword_table_dict: | |||
| return | |||
| else: | |||
| index_struct_dict = dataset_keyword_table.keyword_table_dict | |||
| index_struct: KeywordTable = load_index_struct_from_dict(index_struct_dict) | |||
| # create index | |||
| index = GPTJIEBAKeywordTableIndex( | |||
| index_struct=index_struct, | |||
| docstore=EmptyDocumentStore(), | |||
| service_context=service_context | |||
| ) | |||
| for node_id in node_ids: | |||
| index.delete(node_id) | |||
| index_struct_dict = index.index_struct.to_dict() | |||
| if not dataset_keyword_table: | |||
| dataset_keyword_table = DatasetKeywordTable( | |||
| dataset_id=self._dataset.id, | |||
| keyword_table=json.dumps(index_struct_dict) | |||
| ) | |||
| db.session.add(dataset_keyword_table) | |||
| else: | |||
| dataset_keyword_table.keyword_table = json.dumps(index_struct_dict) | |||
| db.session.commit() | |||
| @property | |||
| def query_index(self) -> Optional[BaseGPTKeywordTableIndex]: | |||
| docstore = DatesetDocumentStore( | |||
| dataset=self._dataset, | |||
| user_id=self._dataset.created_by, | |||
| embedding_model_name="text-embedding-ada-002" | |||
| ) | |||
| service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id) | |||
| dataset_keyword_table = self.get_keyword_table() | |||
| if not dataset_keyword_table or not dataset_keyword_table.keyword_table_dict: | |||
| return None | |||
| index_struct: KeywordTable = load_index_struct_from_dict(dataset_keyword_table.keyword_table_dict) | |||
| return GPTJIEBAKeywordTableIndex(index_struct=index_struct, docstore=docstore, service_context=service_context) | |||
| def get_keyword_table(self): | |||
| dataset_keyword_table = self._dataset.dataset_keyword_table | |||
| if dataset_keyword_table: | |||
| return dataset_keyword_table | |||
| return None | |||
| def update_segment_keywords(self, node_id: str, keywords: List[str]): | |||
| document_segment = db.session.query(DocumentSegment).filter(DocumentSegment.index_node_id == node_id).first() | |||
| if document_segment: | |||
| document_segment.keywords = keywords | |||
| db.session.commit() | |||
| @@ -0,0 +1,33 @@ | |||
| import re | |||
| from typing import Set | |||
| import jieba | |||
| from jieba.analyse import default_tfidf | |||
| from core.index.keyword_table_index.stopwords import STOPWORDS | |||
| class JiebaKeywordTableHandler: | |||
| def __init__(self): | |||
| default_tfidf.stop_words = STOPWORDS | |||
| def extract_keywords(self, text: str, max_keywords_per_chunk: int = 10) -> Set[str]: | |||
| """Extract keywords with JIEBA tfidf.""" | |||
| keywords = jieba.analyse.extract_tags( | |||
| sentence=text, | |||
| topK=max_keywords_per_chunk, | |||
| ) | |||
| return set(self._expand_tokens_with_subtokens(keywords)) | |||
| def _expand_tokens_with_subtokens(self, tokens: Set[str]) -> Set[str]: | |||
| """Get subtokens from a list of tokens, filtering out stopwords.""" | |||
| results = set() | |||
| for token in tokens: | |||
| results.add(token) | |||
| sub_tokens = re.findall(r"\w+", token) | |||
| if len(sub_tokens) > 1: | |||
| results.update({w for w in sub_tokens if w not in STOPWORDS}) | |||
| return results | |||
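| A minimal usage sketch of the handler above; the sample sentence and keyword count are illustrative: | |||
| from core.index.keyword_table_index.jieba_keyword_table_handler import JiebaKeywordTableHandler | |||
| handler = JiebaKeywordTableHandler() | |||
| # top TF-IDF keywords plus their \w+ sub-tokens, with stopwords filtered out | |||
| keywords = handler.extract_keywords("Dify 支持基于关键词表的数据集检索", max_keywords_per_chunk=5) | |||
| print(sorted(keywords)) | |||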
| @@ -0,0 +1,238 @@ | |||
| import json | |||
| from collections import defaultdict | |||
| from typing import Any, List, Optional, Dict | |||
| from langchain.schema import Document, BaseRetriever | |||
| from pydantic import BaseModel, Field, Extra | |||
| from core.index.base import BaseIndex | |||
| from core.index.keyword_table_index.jieba_keyword_table_handler import JiebaKeywordTableHandler | |||
| from extensions.ext_database import db | |||
| from models.dataset import Dataset, DocumentSegment, DatasetKeywordTable | |||
| class KeywordTableConfig(BaseModel): | |||
| max_keywords_per_chunk: int = 10 | |||
| class KeywordTableIndex(BaseIndex): | |||
| def __init__(self, dataset: Dataset, config: KeywordTableConfig = KeywordTableConfig()): | |||
| super().__init__(dataset) | |||
| self._config = config | |||
| def create(self, texts: list[Document], **kwargs) -> BaseIndex: | |||
| keyword_table_handler = JiebaKeywordTableHandler() | |||
| keyword_table = {} | |||
| for text in texts: | |||
| keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk) | |||
| self._update_segment_keywords(text.metadata['doc_id'], list(keywords)) | |||
| keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords)) | |||
| dataset_keyword_table = DatasetKeywordTable( | |||
| dataset_id=self.dataset.id, | |||
| keyword_table=json.dumps({ | |||
| '__type__': 'keyword_table', | |||
| '__data__': { | |||
| "index_id": self.dataset.id, | |||
| "summary": None, | |||
| "table": {} | |||
| } | |||
| }, cls=SetEncoder) | |||
| ) | |||
| db.session.add(dataset_keyword_table) | |||
| db.session.commit() | |||
| self._save_dataset_keyword_table(keyword_table) | |||
| return self | |||
| def add_texts(self, texts: list[Document], **kwargs): | |||
| keyword_table_handler = JiebaKeywordTableHandler() | |||
| keyword_table = self._get_dataset_keyword_table() | |||
| for text in texts: | |||
| keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk) | |||
| self._update_segment_keywords(text.metadata['doc_id'], list(keywords)) | |||
| keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords)) | |||
| self._save_dataset_keyword_table(keyword_table) | |||
| def text_exists(self, id: str) -> bool: | |||
| keyword_table = self._get_dataset_keyword_table() | |||
| return id in set.union(*keyword_table.values()) if keyword_table else False | |||
| def delete_by_ids(self, ids: list[str]) -> None: | |||
| keyword_table = self._get_dataset_keyword_table() | |||
| keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids) | |||
| self._save_dataset_keyword_table(keyword_table) | |||
| def delete_by_document_id(self, document_id: str): | |||
| # get segment ids by document_id | |||
| segments = db.session.query(DocumentSegment).filter( | |||
| DocumentSegment.dataset_id == self.dataset.id, | |||
| DocumentSegment.document_id == document_id | |||
| ).all() | |||
| ids = [segment.id for segment in segments] | |||
| keyword_table = self._get_dataset_keyword_table() | |||
| keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids) | |||
| self._save_dataset_keyword_table(keyword_table) | |||
| def get_retriever(self, **kwargs: Any) -> BaseRetriever: | |||
| return KeywordTableRetriever(index=self, **kwargs) | |||
| def search( | |||
| self, query: str, | |||
| **kwargs: Any | |||
| ) -> List[Document]: | |||
| keyword_table = self._get_dataset_keyword_table() | |||
| search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {} | |||
| k = search_kwargs.get('k') if search_kwargs.get('k') else 4 | |||
| sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table, query, k) | |||
| documents = [] | |||
| for chunk_index in sorted_chunk_indices: | |||
| segment = db.session.query(DocumentSegment).filter( | |||
| DocumentSegment.dataset_id == self.dataset.id, | |||
| DocumentSegment.index_node_id == chunk_index | |||
| ).first() | |||
| if segment: | |||
| documents.append(Document( | |||
| page_content=segment.content, | |||
| metadata={ | |||
| "doc_id": chunk_index, | |||
| "document_id": segment.document_id, | |||
| "dataset_id": segment.dataset_id, | |||
| } | |||
| )) | |||
| return documents | |||
| def delete(self) -> None: | |||
| dataset_keyword_table = self.dataset.dataset_keyword_table | |||
| if dataset_keyword_table: | |||
| db.session.delete(dataset_keyword_table) | |||
| db.session.commit() | |||
| def _save_dataset_keyword_table(self, keyword_table): | |||
| keyword_table_dict = { | |||
| '__type__': 'keyword_table', | |||
| '__data__': { | |||
| "index_id": self.dataset.id, | |||
| "summary": None, | |||
| "table": keyword_table | |||
| } | |||
| } | |||
| self.dataset.dataset_keyword_table.keyword_table = json.dumps(keyword_table_dict, cls=SetEncoder) | |||
| db.session.commit() | |||
| def _get_dataset_keyword_table(self) -> Optional[dict]: | |||
| dataset_keyword_table = self.dataset.dataset_keyword_table | |||
| if dataset_keyword_table: | |||
| if dataset_keyword_table.keyword_table_dict: | |||
| return dataset_keyword_table.keyword_table_dict['__data__']['table'] | |||
| else: | |||
| dataset_keyword_table = DatasetKeywordTable( | |||
| dataset_id=self.dataset.id, | |||
| keyword_table=json.dumps({ | |||
| '__type__': 'keyword_table', | |||
| '__data__': { | |||
| "index_id": self.dataset.id, | |||
| "summary": None, | |||
| "table": {} | |||
| } | |||
| }, cls=SetEncoder) | |||
| ) | |||
| db.session.add(dataset_keyword_table) | |||
| db.session.commit() | |||
| return {} | |||
| def _add_text_to_keyword_table(self, keyword_table: dict, id: str, keywords: list[str]) -> dict: | |||
| for keyword in keywords: | |||
| if keyword not in keyword_table: | |||
| keyword_table[keyword] = set() | |||
| keyword_table[keyword].add(id) | |||
| return keyword_table | |||
| def _delete_ids_from_keyword_table(self, keyword_table: dict, ids: list[str]) -> dict: | |||
| # get set of ids that correspond to node | |||
| node_idxs_to_delete = set(ids) | |||
| # delete node_idxs from keyword to node idxs mapping | |||
| keywords_to_delete = set() | |||
| for keyword, node_idxs in keyword_table.items(): | |||
| if node_idxs_to_delete.intersection(node_idxs): | |||
| keyword_table[keyword] = node_idxs.difference( | |||
| node_idxs_to_delete | |||
| ) | |||
| if not keyword_table[keyword]: | |||
| keywords_to_delete.add(keyword) | |||
| for keyword in keywords_to_delete: | |||
| del keyword_table[keyword] | |||
| return keyword_table | |||
| def _retrieve_ids_by_query(self, keyword_table: dict, query: str, k: int = 4): | |||
| keyword_table_handler = JiebaKeywordTableHandler() | |||
| keywords = keyword_table_handler.extract_keywords(query) | |||
| # go through text chunks in order of most matching keywords | |||
| chunk_indices_count: Dict[str, int] = defaultdict(int) | |||
| keywords = [keyword for keyword in keywords if keyword in set(keyword_table.keys())] | |||
| for keyword in keywords: | |||
| for node_id in keyword_table[keyword]: | |||
| chunk_indices_count[node_id] += 1 | |||
| sorted_chunk_indices = sorted( | |||
| list(chunk_indices_count.keys()), | |||
| key=lambda x: chunk_indices_count[x], | |||
| reverse=True, | |||
| ) | |||
| return sorted_chunk_indices[: k] | |||
| def _update_segment_keywords(self, node_id: str, keywords: List[str]): | |||
| document_segment = db.session.query(DocumentSegment).filter(DocumentSegment.index_node_id == node_id).first() | |||
| if document_segment: | |||
| document_segment.keywords = keywords | |||
| db.session.commit() | |||
| class KeywordTableRetriever(BaseRetriever, BaseModel): | |||
| index: KeywordTableIndex | |||
| search_kwargs: dict = Field(default_factory=dict) | |||
| class Config: | |||
| """Configuration for this pydantic object.""" | |||
| extra = Extra.forbid | |||
| arbitrary_types_allowed = True | |||
| def get_relevant_documents(self, query: str) -> List[Document]: | |||
| """Get documents relevant for a query. | |||
| Args: | |||
| query: string to find relevant documents for | |||
| Returns: | |||
| List of relevant documents | |||
| """ | |||
| return self.index.search(query, **self.search_kwargs) | |||
| async def aget_relevant_documents(self, query: str) -> List[Document]: | |||
| raise NotImplementedError("KeywordTableRetriever does not support async") | |||
| class SetEncoder(json.JSONEncoder): | |||
| def default(self, obj): | |||
| if isinstance(obj, set): | |||
| return list(obj) | |||
| return super().default(obj) | |||
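| A hedged end-to-end sketch of the new KeywordTableIndex; it assumes a Flask app context so db.session is available, and the dataset lookup, document text and doc_id are placeholders: | |||
| from langchain.schema import Document | |||
| from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig | |||
| from extensions.ext_database import db | |||
| from models.dataset import Dataset | |||
| dataset = db.session.query(Dataset).first()  # placeholder lookup | |||
| index = KeywordTableIndex(dataset=dataset, config=KeywordTableConfig(max_keywords_per_chunk=10)) | |||
| index.create([Document(page_content="关键词表索引示例文本", metadata={'doc_id': 'node-1'})]) | |||
| docs = index.search("关键词", search_kwargs={'k': 4}) | |||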
| @@ -1,79 +0,0 @@ | |||
| from typing import ( | |||
| Any, | |||
| Dict, | |||
| Optional, Sequence, | |||
| ) | |||
| from llama_index.indices.response.response_synthesis import ResponseSynthesizer | |||
| from llama_index.indices.response.response_builder import ResponseMode, BaseResponseBuilder, get_response_builder | |||
| from llama_index.indices.service_context import ServiceContext | |||
| from llama_index.optimization.optimizer import BaseTokenUsageOptimizer | |||
| from llama_index.prompts.prompts import ( | |||
| QuestionAnswerPrompt, | |||
| RefinePrompt, | |||
| SimpleInputPrompt, | |||
| ) | |||
| from llama_index.types import RESPONSE_TEXT_TYPE | |||
| class EnhanceResponseSynthesizer(ResponseSynthesizer): | |||
| @classmethod | |||
| def from_args( | |||
| cls, | |||
| service_context: ServiceContext, | |||
| streaming: bool = False, | |||
| use_async: bool = False, | |||
| text_qa_template: Optional[QuestionAnswerPrompt] = None, | |||
| refine_template: Optional[RefinePrompt] = None, | |||
| simple_template: Optional[SimpleInputPrompt] = None, | |||
| response_mode: ResponseMode = ResponseMode.DEFAULT, | |||
| response_kwargs: Optional[Dict] = None, | |||
| optimizer: Optional[BaseTokenUsageOptimizer] = None, | |||
| ) -> "ResponseSynthesizer": | |||
| response_builder: Optional[BaseResponseBuilder] = None | |||
| if response_mode != ResponseMode.NO_TEXT: | |||
| if response_mode == 'no_synthesizer': | |||
| response_builder = NoSynthesizer( | |||
| service_context=service_context, | |||
| simple_template=simple_template, | |||
| streaming=streaming, | |||
| ) | |||
| else: | |||
| response_builder = get_response_builder( | |||
| service_context, | |||
| text_qa_template, | |||
| refine_template, | |||
| simple_template, | |||
| response_mode, | |||
| use_async=use_async, | |||
| streaming=streaming, | |||
| ) | |||
| return cls(response_builder, response_mode, response_kwargs, optimizer) | |||
| class NoSynthesizer(BaseResponseBuilder): | |||
| def __init__( | |||
| self, | |||
| service_context: ServiceContext, | |||
| simple_template: Optional[SimpleInputPrompt] = None, | |||
| streaming: bool = False, | |||
| ) -> None: | |||
| super().__init__(service_context, streaming) | |||
| async def aget_response( | |||
| self, | |||
| query_str: str, | |||
| text_chunks: Sequence[str], | |||
| prev_response: Optional[str] = None, | |||
| **response_kwargs: Any, | |||
| ) -> RESPONSE_TEXT_TYPE: | |||
| return "\n".join(text_chunks) | |||
| def get_response( | |||
| self, | |||
| query_str: str, | |||
| text_chunks: Sequence[str], | |||
| prev_response: Optional[str] = None, | |||
| **response_kwargs: Any, | |||
| ) -> RESPONSE_TEXT_TYPE: | |||
| return "\n".join(text_chunks) | |||
| @@ -1,22 +0,0 @@ | |||
| from pathlib import Path | |||
| from typing import Dict | |||
| from bs4 import BeautifulSoup | |||
| from llama_index.readers.file.base_parser import BaseParser | |||
| class HTMLParser(BaseParser): | |||
| """HTML parser.""" | |||
| def _init_parser(self) -> Dict: | |||
| """Init parser.""" | |||
| return {} | |||
| def parse_file(self, file: Path, errors: str = "ignore") -> str: | |||
| """Parse file.""" | |||
| with open(file, "rb") as fp: | |||
| soup = BeautifulSoup(fp, 'html.parser') | |||
| text = soup.get_text() | |||
| text = text.strip() if text else '' | |||
| return text | |||
| @@ -1,111 +0,0 @@ | |||
| """Markdown parser. | |||
| Contains parser for md files. | |||
| """ | |||
| import re | |||
| from pathlib import Path | |||
| from typing import Any, Dict, List, Optional, Tuple, Union, cast | |||
| from llama_index.readers.file.base_parser import BaseParser | |||
| class MarkdownParser(BaseParser): | |||
| """Markdown parser. | |||
| Extract text from markdown files. | |||
| Returns dictionary with keys as headers and values as the text between headers. | |||
| """ | |||
| def __init__( | |||
| self, | |||
| *args: Any, | |||
| remove_hyperlinks: bool = True, | |||
| remove_images: bool = True, | |||
| **kwargs: Any, | |||
| ) -> None: | |||
| """Init params.""" | |||
| super().__init__(*args, **kwargs) | |||
| self._remove_hyperlinks = remove_hyperlinks | |||
| self._remove_images = remove_images | |||
| def markdown_to_tups(self, markdown_text: str) -> List[Tuple[Optional[str], str]]: | |||
| """Convert a markdown file to a dictionary. | |||
| The keys are the headers and the values are the text under each header. | |||
| """ | |||
| markdown_tups: List[Tuple[Optional[str], str]] = [] | |||
| lines = markdown_text.split("\n") | |||
| current_header = None | |||
| current_text = "" | |||
| for line in lines: | |||
| header_match = re.match(r"^#+\s", line) | |||
| if header_match: | |||
| if current_header is not None: | |||
| markdown_tups.append((current_header, current_text)) | |||
| current_header = line | |||
| current_text = "" | |||
| else: | |||
| current_text += line + "\n" | |||
| markdown_tups.append((current_header, current_text)) | |||
| if current_header is not None: | |||
| # pass linting, assert keys are defined | |||
| markdown_tups = [ | |||
| (re.sub(r"#", "", cast(str, key)).strip(), re.sub(r"<.*?>", "", value)) | |||
| for key, value in markdown_tups | |||
| ] | |||
| else: | |||
| markdown_tups = [ | |||
| (key, re.sub("\n", "", value)) for key, value in markdown_tups | |||
| ] | |||
| return markdown_tups | |||
| def remove_images(self, content: str) -> str: | |||
| """Get a dictionary of a markdown file from its path.""" | |||
| pattern = r"!{1}\[\[(.*)\]\]" | |||
| content = re.sub(pattern, "", content) | |||
| return content | |||
| def remove_hyperlinks(self, content: str) -> str: | |||
| """Get a dictionary of a markdown file from its path.""" | |||
| pattern = r"\[(.*?)\]\((.*?)\)" | |||
| content = re.sub(pattern, r"\1", content) | |||
| return content | |||
| def _init_parser(self) -> Dict: | |||
| """Initialize the parser with the config.""" | |||
| return {} | |||
| def parse_tups( | |||
| self, filepath: Path, errors: str = "ignore" | |||
| ) -> List[Tuple[Optional[str], str]]: | |||
| """Parse file into tuples.""" | |||
| with open(filepath, "r", encoding="utf-8") as f: | |||
| content = f.read() | |||
| if self._remove_hyperlinks: | |||
| content = self.remove_hyperlinks(content) | |||
| if self._remove_images: | |||
| content = self.remove_images(content) | |||
| markdown_tups = self.markdown_to_tups(content) | |||
| return markdown_tups | |||
| def parse_file( | |||
| self, filepath: Path, errors: str = "ignore" | |||
| ) -> Union[str, List[str]]: | |||
| """Parse file into string.""" | |||
| tups = self.parse_tups(filepath, errors=errors) | |||
| results = [] | |||
| # TODO: don't include headers right now | |||
| for header, value in tups: | |||
| if header is None: | |||
| results.append(value) | |||
| else: | |||
| results.append(f"\n\n{header}\n{value}") | |||
| return results | |||
| @@ -1,56 +0,0 @@ | |||
| from pathlib import Path | |||
| from typing import Dict | |||
| from flask import current_app | |||
| from llama_index.readers.file.base_parser import BaseParser | |||
| from pypdf import PdfReader | |||
| from extensions.ext_storage import storage | |||
| from models.model import UploadFile | |||
| class PDFParser(BaseParser): | |||
| """PDF parser.""" | |||
| def _init_parser(self) -> Dict: | |||
| """Init parser.""" | |||
| return {} | |||
| def parse_file(self, file: Path, errors: str = "ignore") -> str: | |||
| """Parse file.""" | |||
| if not current_app.config.get('PDF_PREVIEW', True): | |||
| return '' | |||
| plaintext_file_key = '' | |||
| plaintext_file_exists = False | |||
| if self._parser_config and 'upload_file' in self._parser_config and self._parser_config['upload_file']: | |||
| upload_file: UploadFile = self._parser_config['upload_file'] | |||
| if upload_file.hash: | |||
| plaintext_file_key = 'upload_files/' + upload_file.tenant_id + '/' + upload_file.hash + '.plaintext' | |||
| try: | |||
| text = storage.load(plaintext_file_key).decode('utf-8') | |||
| plaintext_file_exists = True | |||
| return text | |||
| except FileNotFoundError: | |||
| pass | |||
| text_list = [] | |||
| with open(file, "rb") as fp: | |||
| # Create a PDF object | |||
| pdf = PdfReader(fp) | |||
| # Get the number of pages in the PDF document | |||
| num_pages = len(pdf.pages) | |||
| # Iterate over every page | |||
| for page in range(num_pages): | |||
| # Extract the text from the page | |||
| page_text = pdf.pages[page].extract_text() | |||
| text_list.append(page_text) | |||
| text = "\n".join(text_list) | |||
| # save plaintext file for caching | |||
| if not plaintext_file_exists and plaintext_file_key: | |||
| storage.save(plaintext_file_key, text.encode('utf-8')) | |||
| return text | |||
| @@ -1,33 +0,0 @@ | |||
| from pathlib import Path | |||
| import json | |||
| from typing import Dict | |||
| from openpyxl import load_workbook | |||
| from llama_index.readers.file.base_parser import BaseParser | |||
| from flask import current_app | |||
| class XLSXParser(BaseParser): | |||
| """XLSX parser.""" | |||
| def _init_parser(self) -> Dict: | |||
| """Init parser""" | |||
| return {} | |||
| def parse_file(self, file: Path, errors: str = "ignore") -> str: | |||
| data = [] | |||
| keys = [] | |||
| with open(file, "r") as fp: | |||
| wb = load_workbook(filename=file, read_only=True) | |||
| # loop over all sheets | |||
| for sheet in wb: | |||
| for row in sheet.iter_rows(values_only=True): | |||
| if all(v is None for v in row): | |||
| continue | |||
| if keys == []: | |||
| keys = list(map(str, row)) | |||
| else: | |||
| row_dict = dict(zip(keys, row)) | |||
| row_dict = {k: v for k, v in row_dict.items() if v} | |||
| data.append(json.dumps(row_dict, ensure_ascii=False)) | |||
| return '\n\n'.join(data) | |||
| @@ -1,136 +0,0 @@ | |||
| import json | |||
| import logging | |||
| from typing import List, Optional | |||
| from llama_index.data_structs import Node | |||
| from requests import ReadTimeout | |||
| from sqlalchemy.exc import IntegrityError | |||
| from tenacity import retry, stop_after_attempt, retry_if_exception_type | |||
| from core.index.index_builder import IndexBuilder | |||
| from core.vector_store.base import BaseGPTVectorStoreIndex | |||
| from extensions.ext_vector_store import vector_store | |||
| from extensions.ext_database import db | |||
| from models.dataset import Dataset, Embedding | |||
| class VectorIndex: | |||
| def __init__(self, dataset: Dataset): | |||
| self._dataset = dataset | |||
| def add_nodes(self, nodes: List[Node], duplicate_check: bool = False): | |||
| if not self._dataset.index_struct_dict: | |||
| index_id = "Vector_index_" + self._dataset.id.replace("-", "_") | |||
| self._dataset.index_struct = json.dumps(vector_store.to_index_struct(index_id)) | |||
| db.session.commit() | |||
| service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id) | |||
| index = vector_store.get_index( | |||
| service_context=service_context, | |||
| index_struct=self._dataset.index_struct_dict | |||
| ) | |||
| if duplicate_check: | |||
| nodes = self._filter_duplicate_nodes(index, nodes) | |||
| embedding_queue_nodes = [] | |||
| embedded_nodes = [] | |||
| for node in nodes: | |||
| node_hash = node.doc_hash | |||
| # if node hash in cached embedding tables, use cached embedding | |||
| embedding = db.session.query(Embedding).filter_by(hash=node_hash).first() | |||
| if embedding: | |||
| node.embedding = embedding.get_embedding() | |||
| embedded_nodes.append(node) | |||
| else: | |||
| embedding_queue_nodes.append(node) | |||
| if embedding_queue_nodes: | |||
| embedding_results = index._get_node_embedding_results( | |||
| embedding_queue_nodes, | |||
| set(), | |||
| ) | |||
| # pre embed nodes for cached embedding | |||
| for embedding_result in embedding_results: | |||
| node = embedding_result.node | |||
| node.embedding = embedding_result.embedding | |||
| try: | |||
| embedding = Embedding(hash=node.doc_hash) | |||
| embedding.set_embedding(node.embedding) | |||
| db.session.add(embedding) | |||
| db.session.commit() | |||
| except IntegrityError: | |||
| db.session.rollback() | |||
| continue | |||
| except: | |||
| logging.exception('Failed to add embedding to db') | |||
| continue | |||
| embedded_nodes.append(node) | |||
| self.index_insert_nodes(index, embedded_nodes) | |||
| @retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3)) | |||
| def index_insert_nodes(self, index: BaseGPTVectorStoreIndex, nodes: List[Node]): | |||
| index.insert_nodes(nodes) | |||
| def del_nodes(self, node_ids: List[str]): | |||
| if not self._dataset.index_struct_dict: | |||
| return | |||
| service_context = IndexBuilder.get_fake_llm_service_context(tenant_id=self._dataset.tenant_id) | |||
| index = vector_store.get_index( | |||
| service_context=service_context, | |||
| index_struct=self._dataset.index_struct_dict | |||
| ) | |||
| for node_id in node_ids: | |||
| self.index_delete_node(index, node_id) | |||
| @retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3)) | |||
| def index_delete_node(self, index: BaseGPTVectorStoreIndex, node_id: str): | |||
| index.delete_node(node_id) | |||
| def del_doc(self, doc_id: str): | |||
| if not self._dataset.index_struct_dict: | |||
| return | |||
| service_context = IndexBuilder.get_fake_llm_service_context(tenant_id=self._dataset.tenant_id) | |||
| index = vector_store.get_index( | |||
| service_context=service_context, | |||
| index_struct=self._dataset.index_struct_dict | |||
| ) | |||
| self.index_delete_doc(index, doc_id) | |||
| @retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3)) | |||
| def index_delete_doc(self, index: BaseGPTVectorStoreIndex, doc_id: str): | |||
| index.delete(doc_id) | |||
| @property | |||
| def query_index(self) -> Optional[BaseGPTVectorStoreIndex]: | |||
| if not self._dataset.index_struct_dict: | |||
| return None | |||
| service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id) | |||
| return vector_store.get_index( | |||
| service_context=service_context, | |||
| index_struct=self._dataset.index_struct_dict | |||
| ) | |||
| def _filter_duplicate_nodes(self, index: BaseGPTVectorStoreIndex, nodes: List[Node]) -> List[Node]: | |||
| for node in nodes: | |||
| node_id = node.doc_id | |||
| exists_duplicate_node = index.exists_by_node_id(node_id) | |||
| if exists_duplicate_node: | |||
| nodes.remove(node) | |||
| return nodes | |||
| @@ -0,0 +1,175 @@ | |||
| import json | |||
| import logging | |||
| from abc import abstractmethod | |||
| from typing import List, Any, cast | |||
| from langchain.embeddings.base import Embeddings | |||
| from langchain.schema import Document, BaseRetriever | |||
| from langchain.vectorstores import VectorStore | |||
| from weaviate import UnexpectedStatusCodeException | |||
| from core.index.base import BaseIndex | |||
| from extensions.ext_database import db | |||
| from models.dataset import Dataset, DocumentSegment | |||
| from models.dataset import Document as DatasetDocument | |||
| class BaseVectorIndex(BaseIndex): | |||
| def __init__(self, dataset: Dataset, embeddings: Embeddings): | |||
| super().__init__(dataset) | |||
| self._embeddings = embeddings | |||
| self._vector_store = None | |||
| def get_type(self) -> str: | |||
| raise NotImplementedError | |||
| @abstractmethod | |||
| def get_index_name(self, dataset: Dataset) -> str: | |||
| raise NotImplementedError | |||
| @abstractmethod | |||
| def to_index_struct(self) -> dict: | |||
| raise NotImplementedError | |||
| @abstractmethod | |||
| def _get_vector_store(self) -> VectorStore: | |||
| raise NotImplementedError | |||
| @abstractmethod | |||
| def _get_vector_store_class(self) -> type: | |||
| raise NotImplementedError | |||
| def search( | |||
| self, query: str, | |||
| **kwargs: Any | |||
| ) -> List[Document]: | |||
| vector_store = self._get_vector_store() | |||
| vector_store = cast(self._get_vector_store_class(), vector_store) | |||
| search_type = kwargs.get('search_type') if kwargs.get('search_type') else 'similarity' | |||
| search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {} | |||
| if search_type == 'similarity_score_threshold': | |||
| score_threshold = search_kwargs.get("score_threshold") | |||
| if (score_threshold is None) or (not isinstance(score_threshold, float)): | |||
| search_kwargs['score_threshold'] = 0.0 | |||
| docs_with_similarity = vector_store.similarity_search_with_relevance_scores( | |||
| query, **search_kwargs | |||
| ) | |||
| docs = [] | |||
| for doc, similarity in docs_with_similarity: | |||
| doc.metadata['score'] = similarity | |||
| docs.append(doc) | |||
| return docs | |||
| # similarity k | |||
| # mmr k, fetch_k, lambda_mult | |||
| # similarity_score_threshold k | |||
| return vector_store.as_retriever( | |||
| search_type=search_type, | |||
| search_kwargs=search_kwargs | |||
| ).get_relevant_documents(query) | |||
| def get_retriever(self, **kwargs: Any) -> BaseRetriever: | |||
| vector_store = self._get_vector_store() | |||
| vector_store = cast(self._get_vector_store_class(), vector_store) | |||
| return vector_store.as_retriever(**kwargs) | |||
| def add_texts(self, texts: list[Document], **kwargs): | |||
| if self._is_origin(): | |||
| self.recreate_dataset(self.dataset) | |||
| vector_store = self._get_vector_store() | |||
| vector_store = cast(self._get_vector_store_class(), vector_store) | |||
| if kwargs.get('duplicate_check', False): | |||
| texts = self._filter_duplicate_texts(texts) | |||
| uuids = self._get_uuids(texts) | |||
| vector_store.add_documents(texts, uuids=uuids) | |||
| def text_exists(self, id: str) -> bool: | |||
| vector_store = self._get_vector_store() | |||
| vector_store = cast(self._get_vector_store_class(), vector_store) | |||
| return vector_store.text_exists(id) | |||
| def delete_by_ids(self, ids: list[str]) -> None: | |||
| if self._is_origin(): | |||
| self.recreate_dataset(self.dataset) | |||
| return | |||
| vector_store = self._get_vector_store() | |||
| vector_store = cast(self._get_vector_store_class(), vector_store) | |||
| for node_id in ids: | |||
| vector_store.del_text(node_id) | |||
| def delete(self) -> None: | |||
| vector_store = self._get_vector_store() | |||
| vector_store = cast(self._get_vector_store_class(), vector_store) | |||
| vector_store.delete() | |||
| def _is_origin(self): | |||
| return False | |||
| def recreate_dataset(self, dataset: Dataset): | |||
| logging.info(f"Recreating dataset {dataset.id}") | |||
| try: | |||
| self.delete() | |||
| except UnexpectedStatusCodeException as e: | |||
| if e.status_code != 400: | |||
| # a 400 status from Weaviate means the index does not exist | |||
| raise e | |||
| dataset_documents = db.session.query(DatasetDocument).filter( | |||
| DatasetDocument.dataset_id == dataset.id, | |||
| DatasetDocument.indexing_status == 'completed', | |||
| DatasetDocument.enabled == True, | |||
| DatasetDocument.archived == False, | |||
| ).all() | |||
| documents = [] | |||
| for dataset_document in dataset_documents: | |||
| segments = db.session.query(DocumentSegment).filter( | |||
| DocumentSegment.document_id == dataset_document.id, | |||
| DocumentSegment.status == 'completed', | |||
| DocumentSegment.enabled == True | |||
| ).all() | |||
| for segment in segments: | |||
| document = Document( | |||
| page_content=segment.content, | |||
| metadata={ | |||
| "doc_id": segment.index_node_id, | |||
| "doc_hash": segment.index_node_hash, | |||
| "document_id": segment.document_id, | |||
| "dataset_id": segment.dataset_id, | |||
| } | |||
| ) | |||
| documents.append(document) | |||
| origin_index_struct = self.dataset.index_struct | |||
| self.dataset.index_struct = None | |||
| if documents: | |||
| try: | |||
| self.create(documents) | |||
| except Exception as e: | |||
| self.dataset.index_struct = origin_index_struct | |||
| raise e | |||
| dataset.index_struct = json.dumps(self.to_index_struct()) | |||
| db.session.commit() | |||
| self.dataset = dataset | |||
| logging.info(f"Dataset {dataset.id} recreated successfully.") | |||
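| A hedged sketch of the retrieval path this base class provides; vector_index stands for any concrete subclass instance (built as in the Qdrant and Weaviate hunks below) and the threshold values are illustrative: | |||
| # vector_index: a QdrantVectorIndex or WeaviateVectorIndex instance (placeholder) | |||
| docs = vector_index.search( | |||
|     "user question", | |||
|     search_type='similarity_score_threshold', | |||
|     search_kwargs={'k': 4, 'score_threshold': 0.5} | |||
| ) | |||
| for doc in docs: | |||
|     # 'score' is attached by BaseVectorIndex.search in the threshold mode | |||
|     print(doc.metadata.get('score'), doc.page_content[:50]) | |||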
| @@ -0,0 +1,116 @@ | |||
| import os | |||
| from typing import Optional, Any, List, cast | |||
| import qdrant_client | |||
| from langchain.embeddings.base import Embeddings | |||
| from langchain.schema import Document, BaseRetriever | |||
| from langchain.vectorstores import VectorStore | |||
| from pydantic import BaseModel | |||
| from core.index.base import BaseIndex | |||
| from core.index.vector_index.base import BaseVectorIndex | |||
| from core.vector_store.qdrant_vector_store import QdrantVectorStore | |||
| from models.dataset import Dataset | |||
| class QdrantConfig(BaseModel): | |||
| endpoint: str | |||
| api_key: Optional[str] | |||
| root_path: Optional[str] | |||
| def to_qdrant_params(self): | |||
| if self.endpoint and self.endpoint.startswith('path:'): | |||
| path = self.endpoint.replace('path:', '') | |||
| if not os.path.isabs(path): | |||
| path = os.path.join(self.root_path, path) | |||
| return { | |||
| 'path': path | |||
| } | |||
| else: | |||
| return { | |||
| 'url': self.endpoint, | |||
| 'api_key': self.api_key, | |||
| } | |||
| class QdrantVectorIndex(BaseVectorIndex): | |||
| def __init__(self, dataset: Dataset, config: QdrantConfig, embeddings: Embeddings): | |||
| super().__init__(dataset, embeddings) | |||
| self._client_config = config | |||
| def get_type(self) -> str: | |||
| return 'qdrant' | |||
| def get_index_name(self, dataset: Dataset) -> str: | |||
| if self.dataset.index_struct_dict: | |||
| return self.dataset.index_struct_dict['vector_store']['collection_name'] | |||
| dataset_id = dataset.id | |||
| return "Index_" + dataset_id.replace("-", "_") | |||
| def to_index_struct(self) -> dict: | |||
| return { | |||
| "type": self.get_type(), | |||
| "vector_store": {"collection_name": self.get_index_name(self.dataset)} | |||
| } | |||
| def create(self, texts: list[Document], **kwargs) -> BaseIndex: | |||
| uuids = self._get_uuids(texts) | |||
| self._vector_store = QdrantVectorStore.from_documents( | |||
| texts, | |||
| self._embeddings, | |||
| collection_name=self.get_index_name(self.dataset), | |||
| ids=uuids, | |||
| content_payload_key='text', | |||
| **self._client_config.to_qdrant_params() | |||
| ) | |||
| return self | |||
| def _get_vector_store(self) -> VectorStore: | |||
| """Only for created index.""" | |||
| if self._vector_store: | |||
| return self._vector_store | |||
| client = qdrant_client.QdrantClient( | |||
| **self._client_config.to_qdrant_params() | |||
| ) | |||
| return QdrantVectorStore( | |||
| client=client, | |||
| collection_name=self.get_index_name(self.dataset), | |||
| embeddings=self._embeddings, | |||
| content_payload_key='text' | |||
| ) | |||
| def _get_vector_store_class(self) -> type: | |||
| return QdrantVectorStore | |||
| def delete_by_document_id(self, document_id: str): | |||
| if self._is_origin(): | |||
| self.recreate_dataset(self.dataset) | |||
| return | |||
| vector_store = self._get_vector_store() | |||
| vector_store = cast(self._get_vector_store_class(), vector_store) | |||
| from qdrant_client.http import models | |||
| vector_store.del_texts(models.Filter( | |||
| must=[ | |||
| models.FieldCondition( | |||
| key="metadata.document_id", | |||
| match=models.MatchValue(value=document_id), | |||
| ), | |||
| ], | |||
| )) | |||
| def _is_origin(self): | |||
| if self.dataset.index_struct_dict: | |||
| class_prefix: str = self.dataset.index_struct_dict['vector_store']['collection_name'] | |||
| if class_prefix.startswith('Vector_'): | |||
| # original class_prefix | |||
| return True | |||
| return False | |||
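| An illustrative construction of the Qdrant-backed index; the local endpoint and the dataset, embeddings and documents variables are assumptions rather than values fixed by this file: | |||
| from flask import current_app | |||
| from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig | |||
| index = QdrantVectorIndex( | |||
|     dataset=dataset,        # an existing models.dataset.Dataset row (placeholder) | |||
|     config=QdrantConfig( | |||
|         endpoint='http://localhost:6333', | |||
|         api_key=None, | |||
|         root_path=current_app.root_path, | |||
|     ), | |||
|     embeddings=embeddings,  # any langchain Embeddings implementation (placeholder) | |||
| ) | |||
| index.create(documents)                     # documents: list[langchain.schema.Document] | |||
| index.delete_by_document_id('document-id')  # drops every chunk of one dataset document | |||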
| @@ -0,0 +1,69 @@ | |||
| import json | |||
| from flask import current_app | |||
| from langchain.embeddings.base import Embeddings | |||
| from core.index.vector_index.base import BaseVectorIndex | |||
| from extensions.ext_database import db | |||
| from models.dataset import Dataset, Document | |||
| class VectorIndex: | |||
| def __init__(self, dataset: Dataset, config: dict, embeddings: Embeddings): | |||
| self._dataset = dataset | |||
| self._embeddings = embeddings | |||
| self._vector_index = self._init_vector_index(dataset, config, embeddings) | |||
| def _init_vector_index(self, dataset: Dataset, config: dict, embeddings: Embeddings) -> BaseVectorIndex: | |||
| vector_type = config.get('VECTOR_STORE') | |||
| if self._dataset.index_struct_dict: | |||
| vector_type = self._dataset.index_struct_dict['type'] | |||
| if not vector_type: | |||
| raise ValueError("Vector store must be specified.") | |||
| if vector_type == "weaviate": | |||
| from core.index.vector_index.weaviate_vector_index import WeaviateVectorIndex, WeaviateConfig | |||
| return WeaviateVectorIndex( | |||
| dataset=dataset, | |||
| config=WeaviateConfig( | |||
| endpoint=config.get('WEAVIATE_ENDPOINT'), | |||
| api_key=config.get('WEAVIATE_API_KEY'), | |||
| batch_size=int(config.get('WEAVIATE_BATCH_SIZE')) | |||
| ), | |||
| embeddings=embeddings | |||
| ) | |||
| elif vector_type == "qdrant": | |||
| from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig | |||
| return QdrantVectorIndex( | |||
| dataset=dataset, | |||
| config=QdrantConfig( | |||
| endpoint=config.get('QDRANT_URL'), | |||
| api_key=config.get('QDRANT_API_KEY'), | |||
| root_path=current_app.root_path | |||
| ), | |||
| embeddings=embeddings | |||
| ) | |||
| else: | |||
| raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.") | |||
| def add_texts(self, texts: list[Document], **kwargs): | |||
| if not self._dataset.index_struct_dict: | |||
| self._vector_index.create(texts, **kwargs) | |||
| self._dataset.index_struct = json.dumps(self._vector_index.to_index_struct()) | |||
| db.session.commit() | |||
| return | |||
| self._vector_index.add_texts(texts, **kwargs) | |||
| def __getattr__(self, name): | |||
| if self._vector_index is not None: | |||
| method = getattr(self._vector_index, name) | |||
| if callable(method): | |||
| return method | |||
| raise AttributeError(f"'VectorIndex' object has no attribute '{name}'") | |||
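| A sketch of resolving the concrete store through this facade; it assumes VECTOR_STORE and the matching QDRANT_* or WEAVIATE_* keys are set in current_app.config, and dataset, embeddings and documents are placeholders: | |||
| from flask import current_app | |||
| from core.index.vector_index.vector_index import VectorIndex | |||
| vector_index = VectorIndex( | |||
|     dataset=dataset, | |||
|     config=current_app.config, | |||
|     embeddings=embeddings, | |||
| ) | |||
| # the first call creates the index and persists dataset.index_struct; later calls append | |||
| vector_index.add_texts(documents) | |||
| docs = vector_index.search("查询", search_kwargs={'k': 2})  # proxied to the concrete index via __getattr__ | |||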
| @@ -0,0 +1,132 @@ | |||
| from typing import Optional, cast | |||
| import weaviate | |||
| from langchain.embeddings.base import Embeddings | |||
| from langchain.schema import Document, BaseRetriever | |||
| from langchain.vectorstores import VectorStore | |||
| from pydantic import BaseModel, root_validator | |||
| from core.index.base import BaseIndex | |||
| from core.index.vector_index.base import BaseVectorIndex | |||
| from core.vector_store.weaviate_vector_store import WeaviateVectorStore | |||
| from models.dataset import Dataset | |||
| class WeaviateConfig(BaseModel): | |||
| endpoint: str | |||
| api_key: Optional[str] | |||
| batch_size: int = 100 | |||
| @root_validator() | |||
| def validate_config(cls, values: dict) -> dict: | |||
| if not values['endpoint']: | |||
| raise ValueError("config WEAVIATE_ENDPOINT is required") | |||
| return values | |||
| class WeaviateVectorIndex(BaseVectorIndex): | |||
| def __init__(self, dataset: Dataset, config: WeaviateConfig, embeddings: Embeddings): | |||
| super().__init__(dataset, embeddings) | |||
| self._client = self._init_client(config) | |||
| def _init_client(self, config: WeaviateConfig) -> weaviate.Client: | |||
| auth_config = weaviate.auth.AuthApiKey(api_key=config.api_key) | |||
| weaviate.connect.connection.has_grpc = False | |||
| client = weaviate.Client( | |||
| url=config.endpoint, | |||
| auth_client_secret=auth_config, | |||
| timeout_config=(5, 60), | |||
| startup_period=None | |||
| ) | |||
| client.batch.configure( | |||
| # `batch_size` takes an `int` value to enable auto-batching | |||
| # (`None` is used for manual batching) | |||
| batch_size=config.batch_size, | |||
| # dynamically update the `batch_size` based on import speed | |||
| dynamic=True, | |||
| # `timeout_retries` takes an `int` value to retry on time outs | |||
| timeout_retries=3, | |||
| ) | |||
| return client | |||
| def get_type(self) -> str: | |||
| return 'weaviate' | |||
| def get_index_name(self, dataset: Dataset) -> str: | |||
| if self.dataset.index_struct_dict: | |||
| class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix'] | |||
| if not class_prefix.endswith('_Node'): | |||
| # original class_prefix | |||
| class_prefix += '_Node' | |||
| return class_prefix | |||
| dataset_id = dataset.id | |||
| return "Vector_index_" + dataset_id.replace("-", "_") + '_Node' | |||
| def to_index_struct(self) -> dict: | |||
| return { | |||
| "type": self.get_type(), | |||
| "vector_store": {"class_prefix": self.get_index_name(self.dataset)} | |||
| } | |||
| def create(self, texts: list[Document], **kwargs) -> BaseIndex: | |||
| uuids = self._get_uuids(texts) | |||
| self._vector_store = WeaviateVectorStore.from_documents( | |||
| texts, | |||
| self._embeddings, | |||
| client=self._client, | |||
| index_name=self.get_index_name(self.dataset), | |||
| uuids=uuids, | |||
| by_text=False | |||
| ) | |||
| return self | |||
| def _get_vector_store(self) -> VectorStore: | |||
| """Only for created index.""" | |||
| if self._vector_store: | |||
| return self._vector_store | |||
| attributes = ['doc_id', 'dataset_id', 'document_id'] | |||
| if self._is_origin(): | |||
| attributes = ['doc_id'] | |||
| return WeaviateVectorStore( | |||
| client=self._client, | |||
| index_name=self.get_index_name(self.dataset), | |||
| text_key='text', | |||
| embedding=self._embeddings, | |||
| attributes=attributes, | |||
| by_text=False | |||
| ) | |||
| def _get_vector_store_class(self) -> type: | |||
| return WeaviateVectorStore | |||
| def delete_by_document_id(self, document_id: str): | |||
| if self._is_origin(): | |||
| self.recreate_dataset(self.dataset) | |||
| return | |||
| vector_store = self._get_vector_store() | |||
| vector_store = cast(self._get_vector_store_class(), vector_store) | |||
| vector_store.del_texts({ | |||
| "operator": "Equal", | |||
| "path": ["document_id"], | |||
| "valueText": document_id | |||
| }) | |||
| def _is_origin(self): | |||
| if self.dataset.index_struct_dict: | |||
| class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix'] | |||
| if not class_prefix.endswith('_Node'): | |||
| # original class_prefix | |||
| return True | |||
| return False | |||
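| A short sketch of the Weaviate variant; the endpoint is illustrative, dataset and embeddings are placeholders, and the root validator rejects an empty WEAVIATE_ENDPOINT: | |||
| from core.index.vector_index.weaviate_vector_index import WeaviateVectorIndex, WeaviateConfig | |||
| config = WeaviateConfig(endpoint='http://localhost:8080', api_key=None, batch_size=100) | |||
| # WeaviateConfig(endpoint='', api_key=None) raises ValueError("config WEAVIATE_ENDPOINT is required") | |||
| index = WeaviateVectorIndex(dataset=dataset, config=config, embeddings=embeddings) | |||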
| @@ -1,35 +1,34 @@ | |||
| import datetime | |||
| import json | |||
| import logging | |||
| import re | |||
| import tempfile | |||
| import time | |||
| from pathlib import Path | |||
| from typing import Optional, List | |||
| import uuid | |||
| from typing import Optional, List, cast | |||
| from flask import current_app | |||
| from flask_login import current_user | |||
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |||
| from langchain.embeddings import OpenAIEmbeddings | |||
| from langchain.schema import Document | |||
| from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter | |||
| from llama_index import SimpleDirectoryReader | |||
| from llama_index.data_structs import Node | |||
| from llama_index.data_structs.node_v2 import DocumentRelationship | |||
| from llama_index.node_parser import SimpleNodeParser, NodeParser | |||
| from llama_index.readers.file.base import DEFAULT_FILE_EXTRACTOR | |||
| from llama_index.readers.file.markdown_parser import MarkdownParser | |||
| from core.data_source.notion import NotionPageReader | |||
| from core.index.readers.xlsx_parser import XLSXParser | |||
| from core.data_loader.file_extractor import FileExtractor | |||
| from core.data_loader.loader.notion import NotionLoader | |||
| from core.docstore.dataset_docstore import DatesetDocumentStore | |||
| from core.index.keyword_table_index import KeywordTableIndex | |||
| from core.index.readers.html_parser import HTMLParser | |||
| from core.index.readers.markdown_parser import MarkdownParser | |||
| from core.index.readers.pdf_parser import PDFParser | |||
| from core.index.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter | |||
| from core.index.vector_index import VectorIndex | |||
| from core.embedding.cached_embedding import CacheEmbedding | |||
| from core.index.index import IndexBuilder | |||
| from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig | |||
| from core.index.vector_index.vector_index import VectorIndex | |||
| from core.llm.error import ProviderTokenNotInitError | |||
| from core.llm.llm_builder import LLMBuilder | |||
| from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter | |||
| from core.llm.token_calculator import TokenCalculator | |||
| from extensions.ext_database import db | |||
| from extensions.ext_redis import redis_client | |||
| from extensions.ext_storage import storage | |||
| from models.dataset import Document, Dataset, DocumentSegment, DatasetProcessRule | |||
| from libs import helper | |||
| from models.dataset import Document as DatasetDocument | |||
| from models.dataset import Dataset, DocumentSegment, DatasetProcessRule | |||
| from models.model import UploadFile | |||
| from models.source import DataSourceBinding | |||
| @@ -40,135 +39,171 @@ class IndexingRunner: | |||
| self.storage = storage | |||
| self.embedding_model_name = embedding_model_name | |||
| def run(self, documents: List[Document]): | |||
| def run(self, dataset_documents: List[DatasetDocument]): | |||
| """Run the indexing process.""" | |||
| for document in documents: | |||
| for dataset_document in dataset_documents: | |||
| try: | |||
| # get dataset | |||
| dataset = Dataset.query.filter_by( | |||
| id=dataset_document.dataset_id | |||
| ).first() | |||
| if not dataset: | |||
| raise ValueError("no dataset found") | |||
| # load file | |||
| text_docs = self._load_data(dataset_document) | |||
| # get the process rule | |||
| processing_rule = db.session.query(DatasetProcessRule). \ | |||
| filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \ | |||
| first() | |||
| # get splitter | |||
| splitter = self._get_splitter(processing_rule) | |||
| # split to documents | |||
| documents = self._step_split( | |||
| text_docs=text_docs, | |||
| splitter=splitter, | |||
| dataset=dataset, | |||
| dataset_document=dataset_document, | |||
| processing_rule=processing_rule | |||
| ) | |||
| # build index | |||
| self._build_index( | |||
| dataset=dataset, | |||
| dataset_document=dataset_document, | |||
| documents=documents | |||
| ) | |||
| except DocumentIsPausedException: | |||
| raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id)) | |||
| except ProviderTokenNotInitError as e: | |||
| dataset_document.indexing_status = 'error' | |||
| dataset_document.error = str(e.description) | |||
| dataset_document.stopped_at = datetime.datetime.utcnow() | |||
| db.session.commit() | |||
| except Exception as e: | |||
| logging.exception("consume document failed") | |||
| dataset_document.indexing_status = 'error' | |||
| dataset_document.error = str(e) | |||
| dataset_document.stopped_at = datetime.datetime.utcnow() | |||
| db.session.commit() | |||
| def run_in_splitting_status(self, dataset_document: DatasetDocument): | |||
| """Run the indexing process when the index_status is splitting.""" | |||
| try: | |||
| # get dataset | |||
| dataset = Dataset.query.filter_by( | |||
| id=document.dataset_id | |||
| id=dataset_document.dataset_id | |||
| ).first() | |||
| if not dataset: | |||
| raise ValueError("no dataset found") | |||
| # get the existing document segments and delete them | |||
| document_segments = DocumentSegment.query.filter_by( | |||
| dataset_id=dataset.id, | |||
| document_id=dataset_document.id | |||
| ).all() | |||
| for document_segment in document_segments: | |||
| db.session.delete(document_segment) | |||
| db.session.commit() | |||
| # load file | |||
| text_docs = self._load_data(document) | |||
| text_docs = self._load_data(dataset_document) | |||
| # get the process rule | |||
| processing_rule = db.session.query(DatasetProcessRule). \ | |||
| filter(DatasetProcessRule.id == document.dataset_process_rule_id). \ | |||
| filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \ | |||
| first() | |||
| # get node parser for splitting | |||
| node_parser = self._get_node_parser(processing_rule) | |||
| # get splitter | |||
| splitter = self._get_splitter(processing_rule) | |||
| # split to nodes | |||
| nodes = self._step_split( | |||
| # split to documents | |||
| documents = self._step_split( | |||
| text_docs=text_docs, | |||
| node_parser=node_parser, | |||
| splitter=splitter, | |||
| dataset=dataset, | |||
| document=document, | |||
| dataset_document=dataset_document, | |||
| processing_rule=processing_rule | |||
| ) | |||
| # build index | |||
| self._build_index( | |||
| dataset=dataset, | |||
| document=document, | |||
| nodes=nodes | |||
| dataset_document=dataset_document, | |||
| documents=documents | |||
| ) | |||
| except DocumentIsPausedException: | |||
| raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id)) | |||
| except ProviderTokenNotInitError as e: | |||
| dataset_document.indexing_status = 'error' | |||
| dataset_document.error = str(e.description) | |||
| dataset_document.stopped_at = datetime.datetime.utcnow() | |||
| db.session.commit() | |||
| except Exception as e: | |||
| logging.exception("consume document failed") | |||
| dataset_document.indexing_status = 'error' | |||
| dataset_document.error = str(e) | |||
| dataset_document.stopped_at = datetime.datetime.utcnow() | |||
| db.session.commit() | |||
| def run_in_splitting_status(self, document: Document): | |||
| """Run the indexing process when the index_status is splitting.""" | |||
| # get dataset | |||
| dataset = Dataset.query.filter_by( | |||
| id=document.dataset_id | |||
| ).first() | |||
| if not dataset: | |||
| raise ValueError("no dataset found") | |||
| # get exist document_segment list and delete | |||
| document_segments = DocumentSegment.query.filter_by( | |||
| dataset_id=dataset.id, | |||
| document_id=document.id | |||
| ).all() | |||
| db.session.delete(document_segments) | |||
| db.session.commit() | |||
| # load file | |||
| text_docs = self._load_data(document) | |||
| # get the process rule | |||
| processing_rule = db.session.query(DatasetProcessRule). \ | |||
| filter(DatasetProcessRule.id == document.dataset_process_rule_id). \ | |||
| first() | |||
| # get node parser for splitting | |||
| node_parser = self._get_node_parser(processing_rule) | |||
| def run_in_indexing_status(self, dataset_document: DatasetDocument): | |||
| """Run the indexing process when the index_status is indexing.""" | |||
| try: | |||
| # get dataset | |||
| dataset = Dataset.query.filter_by( | |||
| id=dataset_document.dataset_id | |||
| ).first() | |||
| # split to nodes | |||
| nodes = self._step_split( | |||
| text_docs=text_docs, | |||
| node_parser=node_parser, | |||
| dataset=dataset, | |||
| document=document, | |||
| processing_rule=processing_rule | |||
| ) | |||
| if not dataset: | |||
| raise ValueError("no dataset found") | |||
| # build index | |||
| self._build_index( | |||
| dataset=dataset, | |||
| document=document, | |||
| nodes=nodes | |||
| ) | |||
| # get the existing document segments; only incomplete ones are re-indexed | |||
| document_segments = DocumentSegment.query.filter_by( | |||
| dataset_id=dataset.id, | |||
| document_id=dataset_document.id | |||
| ).all() | |||
| documents = [] | |||
| if document_segments: | |||
| for document_segment in document_segments: | |||
| # rebuild a Document for each segment that has not finished indexing | |||
| if document_segment.status != "completed": | |||
| document = Document( | |||
| page_content=document_segment.content, | |||
| metadata={ | |||
| "doc_id": document_segment.index_node_id, | |||
| "doc_hash": document_segment.index_node_hash, | |||
| "document_id": document_segment.document_id, | |||
| "dataset_id": document_segment.dataset_id, | |||
| } | |||
| ) | |||
| documents.append(document) | |||
| def run_in_indexing_status(self, document: Document): | |||
| """Run the indexing process when the index_status is indexing.""" | |||
| # get dataset | |||
| dataset = Dataset.query.filter_by( | |||
| id=document.dataset_id | |||
| ).first() | |||
| if not dataset: | |||
| raise ValueError("no dataset found") | |||
| # get exist document_segment list and delete | |||
| document_segments = DocumentSegment.query.filter_by( | |||
| dataset_id=dataset.id, | |||
| document_id=document.id | |||
| ).all() | |||
| nodes = [] | |||
| if document_segments: | |||
| for document_segment in document_segments: | |||
| # transform segment to node | |||
| if document_segment.status != "completed": | |||
| relationships = { | |||
| DocumentRelationship.SOURCE: document_segment.document_id, | |||
| } | |||
| previous_segment = document_segment.previous_segment | |||
| if previous_segment: | |||
| relationships[DocumentRelationship.PREVIOUS] = previous_segment.index_node_id | |||
| next_segment = document_segment.next_segment | |||
| if next_segment: | |||
| relationships[DocumentRelationship.NEXT] = next_segment.index_node_id | |||
| node = Node( | |||
| doc_id=document_segment.index_node_id, | |||
| doc_hash=document_segment.index_node_hash, | |||
| text=document_segment.content, | |||
| extra_info=None, | |||
| node_info=None, | |||
| relationships=relationships | |||
| ) | |||
| nodes.append(node) | |||
| # build index | |||
| self._build_index( | |||
| dataset=dataset, | |||
| document=document, | |||
| nodes=nodes | |||
| ) | |||
| # build index | |||
| self._build_index( | |||
| dataset=dataset, | |||
| dataset_document=dataset_document, | |||
| documents=documents | |||
| ) | |||
| except DocumentIsPausedException: | |||
| raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id)) | |||
| except ProviderTokenNotInitError as e: | |||
| dataset_document.indexing_status = 'error' | |||
| dataset_document.error = str(e.description) | |||
| dataset_document.stopped_at = datetime.datetime.utcnow() | |||
| db.session.commit() | |||
| except Exception as e: | |||
| logging.exception("consume document failed") | |||
| dataset_document.indexing_status = 'error' | |||
| dataset_document.error = str(e) | |||
| dataset_document.stopped_at = datetime.datetime.utcnow() | |||
| db.session.commit() | |||
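| A hedged sketch of driving the runner, typically from a Celery task; it assumes the constructor takes embedding_model_name (as the context lines above suggest) and that dataset_documents is a list of models.dataset.Document (DatasetDocument) rows: | |||
| runner = IndexingRunner(embedding_model_name="text-embedding-ada-002") | |||
| runner.run(dataset_documents) | |||
| # recovery paths used when a run was interrupted: | |||
| # runner.run_in_splitting_status(dataset_document) | |||
| # runner.run_in_indexing_status(dataset_document) | |||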
| def file_indexing_estimate(self, file_details: List[UploadFile], tmp_processing_rule: dict) -> dict: | |||
| """ | |||
| @@ -179,28 +214,28 @@ class IndexingRunner: | |||
| total_segments = 0 | |||
| for file_detail in file_details: | |||
| # load data from file | |||
| text_docs = self._load_data_from_file(file_detail) | |||
| text_docs = FileExtractor.load(file_detail) | |||
| processing_rule = DatasetProcessRule( | |||
| mode=tmp_processing_rule["mode"], | |||
| rules=json.dumps(tmp_processing_rule["rules"]) | |||
| ) | |||
| # get node parser for splitting | |||
| node_parser = self._get_node_parser(processing_rule) | |||
| # get splitter | |||
| splitter = self._get_splitter(processing_rule) | |||
| # split to nodes | |||
| nodes = self._split_to_nodes( | |||
| # split to documents | |||
| documents = self._split_to_documents( | |||
| text_docs=text_docs, | |||
| node_parser=node_parser, | |||
| splitter=splitter, | |||
| processing_rule=processing_rule | |||
| ) | |||
| total_segments += len(nodes) | |||
| for node in nodes: | |||
| total_segments += len(documents) | |||
| for document in documents: | |||
| if len(preview_texts) < 5: | |||
| preview_texts.append(node.get_text()) | |||
| preview_texts.append(document.page_content) | |||
| tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, node.get_text()) | |||
| tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content) | |||
| return { | |||
| "total_segments": total_segments, | |||
| @@ -230,35 +265,36 @@ class IndexingRunner: | |||
| ).first() | |||
| if not data_source_binding: | |||
| raise ValueError('Data source binding not found.') | |||
| reader = NotionPageReader(integration_token=data_source_binding.access_token) | |||
| for page in notion_info['pages']: | |||
| if page['type'] == 'page': | |||
| page_ids = [page['page_id']] | |||
| documents = reader.load_data_as_documents(page_ids=page_ids) | |||
| elif page['type'] == 'database': | |||
| documents = reader.load_data_as_documents(database_id=page['page_id']) | |||
| else: | |||
| documents = [] | |||
| loader = NotionLoader( | |||
| notion_access_token=data_source_binding.access_token, | |||
| notion_workspace_id=workspace_id, | |||
| notion_obj_id=page['page_id'], | |||
| notion_page_type=page['type'] | |||
| ) | |||
| documents = loader.load() | |||
| processing_rule = DatasetProcessRule( | |||
| mode=tmp_processing_rule["mode"], | |||
| rules=json.dumps(tmp_processing_rule["rules"]) | |||
| ) | |||
| # get node parser for splitting | |||
| node_parser = self._get_node_parser(processing_rule) | |||
| # get splitter | |||
| splitter = self._get_splitter(processing_rule) | |||
| # split to nodes | |||
| nodes = self._split_to_nodes( | |||
| # split to documents | |||
| documents = self._split_to_documents( | |||
| text_docs=documents, | |||
| node_parser=node_parser, | |||
| splitter=splitter, | |||
| processing_rule=processing_rule | |||
| ) | |||
| total_segments += len(nodes) | |||
| for node in nodes: | |||
| total_segments += len(documents) | |||
| for document in documents: | |||
| if len(preview_texts) < 5: | |||
| preview_texts.append(node.get_text()) | |||
| preview_texts.append(document.page_content) | |||
| tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, node.get_text()) | |||
| tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content) | |||
| return { | |||
| "total_segments": total_segments, | |||
| @@ -268,14 +304,14 @@ class IndexingRunner: | |||
| "preview": preview_texts | |||
| } | |||
| def _load_data(self, document: Document) -> List[Document]: | |||
| def _load_data(self, dataset_document: DatasetDocument) -> List[Document]: | |||
| # load file | |||
| if document.data_source_type not in ["upload_file", "notion_import"]: | |||
| if dataset_document.data_source_type not in ["upload_file", "notion_import"]: | |||
| return [] | |||
| data_source_info = document.data_source_info_dict | |||
| data_source_info = dataset_document.data_source_info_dict | |||
| text_docs = [] | |||
| if document.data_source_type == 'upload_file': | |||
| if dataset_document.data_source_type == 'upload_file': | |||
| if not data_source_info or 'upload_file_id' not in data_source_info: | |||
| raise ValueError("no upload file found") | |||
| @@ -283,47 +319,28 @@ class IndexingRunner: | |||
| filter(UploadFile.id == data_source_info['upload_file_id']). \ | |||
| one_or_none() | |||
| text_docs = self._load_data_from_file(file_detail) | |||
| elif document.data_source_type == 'notion_import': | |||
| if not data_source_info or 'notion_page_id' not in data_source_info \ | |||
| or 'notion_workspace_id' not in data_source_info: | |||
| raise ValueError("no notion page found") | |||
| workspace_id = data_source_info['notion_workspace_id'] | |||
| page_id = data_source_info['notion_page_id'] | |||
| page_type = data_source_info['type'] | |||
| data_source_binding = DataSourceBinding.query.filter( | |||
| db.and_( | |||
| DataSourceBinding.tenant_id == document.tenant_id, | |||
| DataSourceBinding.provider == 'notion', | |||
| DataSourceBinding.disabled == False, | |||
| DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"' | |||
| ) | |||
| ).first() | |||
| if not data_source_binding: | |||
| raise ValueError('Data source binding not found.') | |||
| if page_type == 'page': | |||
| # add page last_edited_time to data_source_info | |||
| self._get_notion_page_last_edited_time(page_id, data_source_binding.access_token, document) | |||
| text_docs = self._load_page_data_from_notion(page_id, data_source_binding.access_token) | |||
| elif page_type == 'database': | |||
| # add page last_edited_time to data_source_info | |||
| self._get_notion_database_last_edited_time(page_id, data_source_binding.access_token, document) | |||
| text_docs = self._load_database_data_from_notion(page_id, data_source_binding.access_token) | |||
| text_docs = FileExtractor.load(file_detail) | |||
| elif dataset_document.data_source_type == 'notion_import': | |||
| loader = NotionLoader.from_document(dataset_document) | |||
| text_docs = loader.load() | |||
| # update document status to splitting | |||
| self._update_document_index_status( | |||
| document_id=document.id, | |||
| document_id=dataset_document.id, | |||
| after_indexing_status="splitting", | |||
| extra_update_params={ | |||
| Document.word_count: sum([len(text_doc.text) for text_doc in text_docs]), | |||
| Document.parsing_completed_at: datetime.datetime.utcnow() | |||
| DatasetDocument.word_count: sum([len(text_doc.page_content) for text_doc in text_docs]), | |||
| DatasetDocument.parsing_completed_at: datetime.datetime.utcnow() | |||
| } | |||
| ) | |||
| # replace doc id to document model id | |||
| text_docs = cast(List[Document], text_docs) | |||
| for text_doc in text_docs: | |||
| # remove invalid symbol | |||
| text_doc.text = self.filter_string(text_doc.get_text()) | |||
| text_doc.doc_id = document.id | |||
| text_doc.page_content = self.filter_string(text_doc.page_content) | |||
| text_doc.metadata['document_id'] = dataset_document.id | |||
| text_doc.metadata['dataset_id'] = dataset_document.dataset_id | |||
| return text_docs | |||
| @@ -331,61 +348,7 @@ class IndexingRunner: | |||
| pattern = re.compile('[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\x80-\xFF]') | |||
| return pattern.sub('', text) | |||
| def _load_data_from_file(self, upload_file: UploadFile) -> List[Document]: | |||
| with tempfile.TemporaryDirectory() as temp_dir: | |||
| suffix = Path(upload_file.key).suffix | |||
| filepath = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" | |||
| self.storage.download(upload_file.key, filepath) | |||
| file_extractor = DEFAULT_FILE_EXTRACTOR.copy() | |||
| file_extractor[".markdown"] = MarkdownParser() | |||
| file_extractor[".md"] = MarkdownParser() | |||
| file_extractor[".html"] = HTMLParser() | |||
| file_extractor[".htm"] = HTMLParser() | |||
| file_extractor[".pdf"] = PDFParser({'upload_file': upload_file}) | |||
| file_extractor[".xlsx"] = XLSXParser() | |||
| loader = SimpleDirectoryReader(input_files=[filepath], file_extractor=file_extractor) | |||
| text_docs = loader.load_data() | |||
| return text_docs | |||
| def _load_page_data_from_notion(self, page_id: str, access_token: str) -> List[Document]: | |||
| page_ids = [page_id] | |||
| reader = NotionPageReader(integration_token=access_token) | |||
| text_docs = reader.load_data_as_documents(page_ids=page_ids) | |||
| return text_docs | |||
| def _load_database_data_from_notion(self, database_id: str, access_token: str) -> List[Document]: | |||
| reader = NotionPageReader(integration_token=access_token) | |||
| text_docs = reader.load_data_as_documents(database_id=database_id) | |||
| return text_docs | |||
| def _get_notion_page_last_edited_time(self, page_id: str, access_token: str, document: Document): | |||
| reader = NotionPageReader(integration_token=access_token) | |||
| last_edited_time = reader.get_page_last_edited_time(page_id) | |||
| data_source_info = document.data_source_info_dict | |||
| data_source_info['last_edited_time'] = last_edited_time | |||
| update_params = { | |||
| Document.data_source_info: json.dumps(data_source_info) | |||
| } | |||
| Document.query.filter_by(id=document.id).update(update_params) | |||
| db.session.commit() | |||
| def _get_notion_database_last_edited_time(self, page_id: str, access_token: str, document: Document): | |||
| reader = NotionPageReader(integration_token=access_token) | |||
| last_edited_time = reader.get_database_last_edited_time(page_id) | |||
| data_source_info = document.data_source_info_dict | |||
| data_source_info['last_edited_time'] = last_edited_time | |||
| update_params = { | |||
| Document.data_source_info: json.dumps(data_source_info) | |||
| } | |||
| Document.query.filter_by(id=document.id).update(update_params) | |||
| db.session.commit() | |||
| def _get_node_parser(self, processing_rule: DatasetProcessRule) -> NodeParser: | |||
| def _get_splitter(self, processing_rule: DatasetProcessRule) -> TextSplitter: | |||
| """ | |||
| Get the TextSplitter object according to the processing rule. | |||
| """ | |||
| @@ -414,68 +377,83 @@ class IndexingRunner: | |||
| separators=["\n\n", "。", ".", " ", ""] | |||
| ) | |||
| return SimpleNodeParser(text_splitter=character_splitter, include_extra_info=True) | |||
| return character_splitter | |||
| def _step_split(self, text_docs: List[Document], node_parser: NodeParser, | |||
| dataset: Dataset, document: Document, processing_rule: DatasetProcessRule) -> List[Node]: | |||
| def _step_split(self, text_docs: List[Document], splitter: TextSplitter, | |||
| dataset: Dataset, dataset_document: DatasetDocument, processing_rule: DatasetProcessRule) \ | |||
| -> List[Document]: | |||
| """ | |||
| Split the text documents into nodes and save them to the document segment. | |||
| Split the text documents into chunk documents and save them to the document segments. | |||
| """ | |||
| nodes = self._split_to_nodes( | |||
| documents = self._split_to_documents( | |||
| text_docs=text_docs, | |||
| node_parser=node_parser, | |||
| splitter=splitter, | |||
| processing_rule=processing_rule | |||
| ) | |||
| # save chunk documents to document segments | |||
| doc_store = DatesetDocumentStore( | |||
| dataset=dataset, | |||
| user_id=document.created_by, | |||
| user_id=dataset_document.created_by, | |||
| embedding_model_name=self.embedding_model_name, | |||
| document_id=document.id | |||
| document_id=dataset_document.id | |||
| ) | |||
| # add document segments | |||
| doc_store.add_documents(nodes) | |||
| doc_store.add_documents(documents) | |||
| # update document status to indexing | |||
| cur_time = datetime.datetime.utcnow() | |||
| self._update_document_index_status( | |||
| document_id=document.id, | |||
| document_id=dataset_document.id, | |||
| after_indexing_status="indexing", | |||
| extra_update_params={ | |||
| Document.cleaning_completed_at: cur_time, | |||
| Document.splitting_completed_at: cur_time, | |||
| DatasetDocument.cleaning_completed_at: cur_time, | |||
| DatasetDocument.splitting_completed_at: cur_time, | |||
| } | |||
| ) | |||
| # update segment status to indexing | |||
| self._update_segments_by_document( | |||
| document_id=document.id, | |||
| dataset_document_id=dataset_document.id, | |||
| update_params={ | |||
| DocumentSegment.status: "indexing", | |||
| DocumentSegment.indexing_at: datetime.datetime.utcnow() | |||
| } | |||
| ) | |||
| return nodes | |||
| return documents | |||
| def _split_to_nodes(self, text_docs: List[Document], node_parser: NodeParser, | |||
| processing_rule: DatasetProcessRule) -> List[Node]: | |||
| def _split_to_documents(self, text_docs: List[Document], splitter: TextSplitter, | |||
| processing_rule: DatasetProcessRule) -> List[Document]: | |||
| """ | |||
| Split the text documents into chunk documents. | |||
| """ | |||
| all_nodes = [] | |||
| all_documents = [] | |||
| for text_doc in text_docs: | |||
| # document clean | |||
| document_text = self._document_clean(text_doc.get_text(), processing_rule) | |||
| text_doc.text = document_text | |||
| document_text = self._document_clean(text_doc.page_content, processing_rule) | |||
| text_doc.page_content = document_text | |||
| # parse document to nodes | |||
| nodes = node_parser.get_nodes_from_documents([text_doc]) | |||
| nodes = [node for node in nodes if node.text is not None and node.text.strip()] | |||
| all_nodes.extend(nodes) | |||
| documents = splitter.split_documents([text_doc]) | |||
| split_documents = [] | |||
| for document in documents: | |||
| if document.page_content is None or not document.page_content.strip(): | |||
| continue | |||
| doc_id = str(uuid.uuid4()) | |||
| hash = helper.generate_text_hash(document.page_content) | |||
| document.metadata['doc_id'] = doc_id | |||
| document.metadata['doc_hash'] = hash | |||
| split_documents.append(document) | |||
| all_documents.extend(split_documents) | |||
| return all_nodes | |||
| return all_documents | |||
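The loop above is the core of the llama_index-to-langchain migration: every chunk is a plain langchain Document whose metadata carries a generated doc_id and a content hash used for change detection. A minimal standalone sketch of that pattern, using a stock RecursiveCharacterTextSplitter and sha256 as stand-ins for the project's own splitter and generate_text_hash helper (both stand-ins are assumptions):

import hashlib
import uuid

from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Stock splitter with the same separator list hinted at in this diff.
splitter = RecursiveCharacterTextSplitter(
    chunk_size=500,
    chunk_overlap=50,
    separators=["\n\n", "。", ".", " ", ""]
)
source = Document(page_content="First paragraph.\n\nSecond paragraph.", metadata={})

chunks = []
for chunk in splitter.split_documents([source]):
    if not chunk.page_content.strip():
        continue
    chunk.metadata['doc_id'] = str(uuid.uuid4())            # becomes the segment's index_node_id
    chunk.metadata['doc_hash'] = hashlib.sha256(            # sha256 stands in for generate_text_hash
        chunk.page_content.encode('utf-8')).hexdigest()
    chunks.append(chunk)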
| def _document_clean(self, text: str, processing_rule: DatasetProcessRule) -> str: | |||
| """ | |||
| @@ -506,37 +484,38 @@ class IndexingRunner: | |||
| return text | |||
| def _build_index(self, dataset: Dataset, document: Document, nodes: List[Node]) -> None: | |||
| def _build_index(self, dataset: Dataset, dataset_document: DatasetDocument, documents: List[Document]) -> None: | |||
| """ | |||
| Build the index for the document. | |||
| """ | |||
| vector_index = VectorIndex(dataset=dataset) | |||
| keyword_table_index = KeywordTableIndex(dataset=dataset) | |||
| vector_index = IndexBuilder.get_index(dataset, 'high_quality') | |||
| keyword_table_index = IndexBuilder.get_index(dataset, 'economy') | |||
| # process documents in chunks of chunk_size | |||
| indexing_start_at = time.perf_counter() | |||
| tokens = 0 | |||
| chunk_size = 100 | |||
| for i in range(0, len(nodes), chunk_size): | |||
| for i in range(0, len(documents), chunk_size): | |||
| # check document is paused | |||
| self._check_document_paused_status(document.id) | |||
| chunk_nodes = nodes[i:i + chunk_size] | |||
| self._check_document_paused_status(dataset_document.id) | |||
| chunk_documents = documents[i:i + chunk_size] | |||
| tokens += sum( | |||
| TokenCalculator.get_num_tokens(self.embedding_model_name, node.get_text()) for node in chunk_nodes | |||
| TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content) | |||
| for document in chunk_documents | |||
| ) | |||
| # save vector index | |||
| if dataset.indexing_technique == "high_quality": | |||
| vector_index.add_nodes(chunk_nodes) | |||
| if vector_index: | |||
| vector_index.add_texts(chunk_documents) | |||
| # save keyword index | |||
| keyword_table_index.add_nodes(chunk_nodes) | |||
| keyword_table_index.add_texts(chunk_documents) | |||
| node_ids = [node.doc_id for node in chunk_nodes] | |||
| document_ids = [document.metadata['doc_id'] for document in chunk_documents] | |||
| db.session.query(DocumentSegment).filter( | |||
| DocumentSegment.document_id == document.id, | |||
| DocumentSegment.index_node_id.in_(node_ids), | |||
| DocumentSegment.document_id == dataset_document.id, | |||
| DocumentSegment.index_node_id.in_(document_ids), | |||
| DocumentSegment.status == "indexing" | |||
| ).update({ | |||
| DocumentSegment.status: "completed", | |||
| @@ -549,12 +528,12 @@ class IndexingRunner: | |||
| # update document status to completed | |||
| self._update_document_index_status( | |||
| document_id=document.id, | |||
| document_id=dataset_document.id, | |||
| after_indexing_status="completed", | |||
| extra_update_params={ | |||
| Document.tokens: tokens, | |||
| Document.completed_at: datetime.datetime.utcnow(), | |||
| Document.indexing_latency: indexing_end_at - indexing_start_at, | |||
| DatasetDocument.tokens: tokens, | |||
| DatasetDocument.completed_at: datetime.datetime.utcnow(), | |||
| DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at, | |||
| } | |||
| ) | |||
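Token accounting in _build_index goes through the internal TokenCalculator helper, which is not shown in this diff. A hedged sketch of the same per-chunk counting using tiktoken directly, assuming the embedding model resolves to the cl100k_base encoding as text-embedding-ada-002 does:

import tiktoken

def count_embedding_tokens(texts, model_name="text-embedding-ada-002"):
    # Rough stand-in for TokenCalculator.get_num_tokens; the real helper may differ.
    encoding = tiktoken.encoding_for_model(model_name)
    return sum(len(encoding.encode(text)) for text in texts)

chunk_texts = ["First chunk of page content.", "Second chunk of page content."]
print(count_embedding_tokens(chunk_texts))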
| @@ -569,25 +548,25 @@ class IndexingRunner: | |||
| """ | |||
| Update the document indexing status. | |||
| """ | |||
| count = Document.query.filter_by(id=document_id, is_paused=True).count() | |||
| count = DatasetDocument.query.filter_by(id=document_id, is_paused=True).count() | |||
| if count > 0: | |||
| raise DocumentIsPausedException() | |||
| update_params = { | |||
| Document.indexing_status: after_indexing_status | |||
| DatasetDocument.indexing_status: after_indexing_status | |||
| } | |||
| if extra_update_params: | |||
| update_params.update(extra_update_params) | |||
| Document.query.filter_by(id=document_id).update(update_params) | |||
| DatasetDocument.query.filter_by(id=document_id).update(update_params) | |||
| db.session.commit() | |||
| def _update_segments_by_document(self, document_id: str, update_params: dict) -> None: | |||
| def _update_segments_by_document(self, dataset_document_id: str, update_params: dict) -> None: | |||
| """ | |||
| Update the document segment by document id. | |||
| """ | |||
| DocumentSegment.query.filter_by(document_id=document_id).update(update_params) | |||
| DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params) | |||
| db.session.commit() | |||
| @@ -1,7 +1,6 @@ | |||
| from typing import Union, Optional | |||
| from typing import Union, Optional, List | |||
| from langchain.callbacks import CallbackManager | |||
| from langchain.llms.fake import FakeListLLM | |||
| from langchain.callbacks.base import BaseCallbackHandler | |||
| from core.constant import llm_constant | |||
| from core.llm.error import ProviderTokenNotInitError | |||
| @@ -32,12 +31,11 @@ class LLMBuilder: | |||
| """ | |||
| @classmethod | |||
| def to_llm(cls, tenant_id: str, model_name: str, **kwargs) -> Union[StreamableOpenAI, StreamableChatOpenAI, FakeListLLM]: | |||
| if model_name == 'fake': | |||
| return FakeListLLM(responses=[]) | |||
| def to_llm(cls, tenant_id: str, model_name: str, **kwargs) -> Union[StreamableOpenAI, StreamableChatOpenAI]: | |||
| provider = cls.get_default_provider(tenant_id) | |||
| model_credentials = cls.get_model_credentials(tenant_id, provider, model_name) | |||
| mode = cls.get_mode_by_model(model_name) | |||
| if mode == 'chat': | |||
| if provider == 'openai': | |||
| @@ -52,16 +50,21 @@ class LLMBuilder: | |||
| else: | |||
| raise ValueError(f"model name {model_name} is not supported.") | |||
| model_credentials = cls.get_model_credentials(tenant_id, provider, model_name) | |||
| model_kwargs = { | |||
| 'top_p': kwargs.get('top_p', 1), | |||
| 'frequency_penalty': kwargs.get('frequency_penalty', 0), | |||
| 'presence_penalty': kwargs.get('presence_penalty', 0), | |||
| } | |||
| model_extras_kwargs = model_kwargs if mode == 'completion' else {'model_kwargs': model_kwargs} | |||
| return llm_cls( | |||
| model_name=model_name, | |||
| temperature=kwargs.get('temperature', 0), | |||
| max_tokens=kwargs.get('max_tokens', 256), | |||
| top_p=kwargs.get('top_p', 1), | |||
| frequency_penalty=kwargs.get('frequency_penalty', 0), | |||
| presence_penalty=kwargs.get('presence_penalty', 0), | |||
| callback_manager=kwargs.get('callback_manager', None), | |||
| **model_extras_kwargs, | |||
| callbacks=kwargs.get('callbacks', None), | |||
| streaming=kwargs.get('streaming', False), | |||
| # request_timeout=None | |||
| **model_credentials | |||
| @@ -69,7 +72,7 @@ class LLMBuilder: | |||
| @classmethod | |||
| def to_llm_from_model(cls, tenant_id: str, model: dict, streaming: bool = False, | |||
| callback_manager: Optional[CallbackManager] = None) -> Union[StreamableOpenAI, StreamableChatOpenAI]: | |||
| callbacks: Optional[List[BaseCallbackHandler]] = None) -> Union[StreamableOpenAI, StreamableChatOpenAI]: | |||
| model_name = model.get("name") | |||
| completion_params = model.get("completion_params", {}) | |||
| @@ -82,7 +85,7 @@ class LLMBuilder: | |||
| frequency_penalty=completion_params.get('frequency_penalty', 0.1), | |||
| presence_penalty=completion_params.get('presence_penalty', 0.1), | |||
| streaming=streaming, | |||
| callback_manager=callback_manager | |||
| callbacks=callbacks | |||
| ) | |||
| @classmethod | |||
| @@ -42,7 +42,10 @@ class AzureProvider(BaseProvider): | |||
| """ | |||
| config = self.get_provider_api_key(model_id=model_id) | |||
| config['openai_api_type'] = 'azure' | |||
| config['deployment_name'] = model_id.replace('.', '') if model_id else None | |||
| if model_id == 'text-embedding-ada-002': | |||
| config['deployment'] = model_id.replace('.', '') if model_id else None | |||
| else: | |||
| config['deployment_name'] = model_id.replace('.', '') if model_id else None | |||
| return config | |||
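The special case exists because the langchain Azure wrappers used elsewhere in this PR take different keyword names: OpenAIEmbeddings is constructed with deployment, while the completion and chat classes take deployment_name. A hedged sketch of the two config shapes this method would produce; the credential fields are supplied by get_provider_api_key in the real code and are elided here:

# Hypothetical shapes; credential fields returned by get_provider_api_key() are elided.
embedding_config = {
    'openai_api_type': 'azure',
    'deployment': 'text-embedding-ada-002',   # keyword expected by OpenAIEmbeddings
    # ... api base / api key fields ...
}
llm_config = {
    'openai_api_type': 'azure',
    'deployment_name': 'gpt-35-turbo',        # keyword expected by AzureOpenAI / AzureChatOpenAI
    # ... api base / api key fields ...
}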
| def get_provider_name(self): | |||
| @@ -1,3 +1,4 @@ | |||
| from langchain.callbacks.manager import CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun, Callbacks | |||
| from langchain.schema import BaseMessage, ChatResult, LLMResult | |||
| from langchain.chat_models import AzureChatOpenAI | |||
| from typing import Optional, List, Dict, Any | |||
| @@ -68,60 +69,22 @@ class StreamableAzureChatOpenAI(AzureChatOpenAI): | |||
| return message_tokens | |||
| def _generate( | |||
| self, messages: List[BaseMessage], stop: Optional[List[str]] = None | |||
| ) -> ChatResult: | |||
| self.callback_manager.on_llm_start( | |||
| {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages], | |||
| verbose=self.verbose | |||
| ) | |||
| chat_result = super()._generate(messages, stop) | |||
| result = LLMResult( | |||
| generations=[chat_result.generations], | |||
| llm_output=chat_result.llm_output | |||
| ) | |||
| self.callback_manager.on_llm_end(result, verbose=self.verbose) | |||
| return chat_result | |||
| async def _agenerate( | |||
| self, messages: List[BaseMessage], stop: Optional[List[str]] = None | |||
| ) -> ChatResult: | |||
| if self.callback_manager.is_async: | |||
| await self.callback_manager.on_llm_start( | |||
| {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages], | |||
| verbose=self.verbose | |||
| ) | |||
| else: | |||
| self.callback_manager.on_llm_start( | |||
| {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages], | |||
| verbose=self.verbose | |||
| ) | |||
| chat_result = super()._generate(messages, stop) | |||
| result = LLMResult( | |||
| generations=[chat_result.generations], | |||
| llm_output=chat_result.llm_output | |||
| ) | |||
| if self.callback_manager.is_async: | |||
| await self.callback_manager.on_llm_end(result, verbose=self.verbose) | |||
| else: | |||
| self.callback_manager.on_llm_end(result, verbose=self.verbose) | |||
| return chat_result | |||
| @handle_llm_exceptions | |||
| def generate( | |||
| self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None | |||
| self, | |||
| messages: List[List[BaseMessage]], | |||
| stop: Optional[List[str]] = None, | |||
| callbacks: Callbacks = None, | |||
| **kwargs: Any, | |||
| ) -> LLMResult: | |||
| return super().generate(messages, stop) | |||
| return super().generate(messages, stop, callbacks, **kwargs) | |||
| @handle_llm_exceptions_async | |||
| async def agenerate( | |||
| self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None | |||
| self, | |||
| messages: List[List[BaseMessage]], | |||
| stop: Optional[List[str]] = None, | |||
| callbacks: Callbacks = None, | |||
| **kwargs: Any, | |||
| ) -> LLMResult: | |||
| return await super().agenerate(messages, stop) | |||
| return await super().agenerate(messages, stop, callbacks, **kwargs) | |||
| @@ -1,5 +1,4 @@ | |||
| import os | |||
| from langchain.callbacks.manager import Callbacks | |||
| from langchain.llms import AzureOpenAI | |||
| from langchain.schema import LLMResult | |||
| from typing import Optional, List, Dict, Mapping, Any | |||
| @@ -53,12 +52,20 @@ class StreamableAzureOpenAI(AzureOpenAI): | |||
| @handle_llm_exceptions | |||
| def generate( | |||
| self, prompts: List[str], stop: Optional[List[str]] = None | |||
| self, | |||
| prompts: List[str], | |||
| stop: Optional[List[str]] = None, | |||
| callbacks: Callbacks = None, | |||
| **kwargs: Any, | |||
| ) -> LLMResult: | |||
| return super().generate(prompts, stop) | |||
| return super().generate(prompts, stop, callbacks, **kwargs) | |||
| @handle_llm_exceptions_async | |||
| async def agenerate( | |||
| self, prompts: List[str], stop: Optional[List[str]] = None | |||
| self, | |||
| prompts: List[str], | |||
| stop: Optional[List[str]] = None, | |||
| callbacks: Callbacks = None, | |||
| **kwargs: Any, | |||
| ) -> LLMResult: | |||
| return await super().agenerate(prompts, stop) | |||
| return await super().agenerate(prompts, stop, callbacks, **kwargs) | |||
| @@ -1,6 +1,7 @@ | |||
| import os | |||
| from langchain.schema import BaseMessage, ChatResult, LLMResult | |||
| from langchain.callbacks.manager import Callbacks | |||
| from langchain.schema import BaseMessage, LLMResult | |||
| from langchain.chat_models import ChatOpenAI | |||
| from typing import Optional, List, Dict, Any | |||
| @@ -70,57 +71,22 @@ class StreamableChatOpenAI(ChatOpenAI): | |||
| return message_tokens | |||
| def _generate( | |||
| self, messages: List[BaseMessage], stop: Optional[List[str]] = None | |||
| ) -> ChatResult: | |||
| self.callback_manager.on_llm_start( | |||
| {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages], verbose=self.verbose | |||
| ) | |||
| chat_result = super()._generate(messages, stop) | |||
| result = LLMResult( | |||
| generations=[chat_result.generations], | |||
| llm_output=chat_result.llm_output | |||
| ) | |||
| self.callback_manager.on_llm_end(result, verbose=self.verbose) | |||
| return chat_result | |||
| async def _agenerate( | |||
| self, messages: List[BaseMessage], stop: Optional[List[str]] = None | |||
| ) -> ChatResult: | |||
| if self.callback_manager.is_async: | |||
| await self.callback_manager.on_llm_start( | |||
| {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages], verbose=self.verbose | |||
| ) | |||
| else: | |||
| self.callback_manager.on_llm_start( | |||
| {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages], verbose=self.verbose | |||
| ) | |||
| chat_result = super()._generate(messages, stop) | |||
| result = LLMResult( | |||
| generations=[chat_result.generations], | |||
| llm_output=chat_result.llm_output | |||
| ) | |||
| if self.callback_manager.is_async: | |||
| await self.callback_manager.on_llm_end(result, verbose=self.verbose) | |||
| else: | |||
| self.callback_manager.on_llm_end(result, verbose=self.verbose) | |||
| return chat_result | |||
| @handle_llm_exceptions | |||
| def generate( | |||
| self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None | |||
| self, | |||
| messages: List[List[BaseMessage]], | |||
| stop: Optional[List[str]] = None, | |||
| callbacks: Callbacks = None, | |||
| **kwargs: Any, | |||
| ) -> LLMResult: | |||
| return super().generate(messages, stop) | |||
| return super().generate(messages, stop, callbacks, **kwargs) | |||
| @handle_llm_exceptions_async | |||
| async def agenerate( | |||
| self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None | |||
| self, | |||
| messages: List[List[BaseMessage]], | |||
| stop: Optional[List[str]] = None, | |||
| callbacks: Callbacks = None, | |||
| **kwargs: Any, | |||
| ) -> LLMResult: | |||
| return await super().agenerate(messages, stop) | |||
| return await super().agenerate(messages, stop, callbacks, **kwargs) | |||
| @@ -1,5 +1,6 @@ | |||
| import os | |||
| from langchain.callbacks.manager import Callbacks | |||
| from langchain.schema import LLMResult | |||
| from typing import Optional, List, Dict, Any, Mapping | |||
| from langchain import OpenAI | |||
| @@ -48,15 +49,22 @@ class StreamableOpenAI(OpenAI): | |||
| "organization": self.openai_organization if self.openai_organization else None, | |||
| }} | |||
| @handle_llm_exceptions | |||
| def generate( | |||
| self, prompts: List[str], stop: Optional[List[str]] = None | |||
| self, | |||
| prompts: List[str], | |||
| stop: Optional[List[str]] = None, | |||
| callbacks: Callbacks = None, | |||
| **kwargs: Any, | |||
| ) -> LLMResult: | |||
| return super().generate(prompts, stop) | |||
| return super().generate(prompts, stop, callbacks, **kwargs) | |||
| @handle_llm_exceptions_async | |||
| async def agenerate( | |||
| self, prompts: List[str], stop: Optional[List[str]] = None | |||
| self, | |||
| prompts: List[str], | |||
| stop: Optional[List[str]] = None, | |||
| callbacks: Callbacks = None, | |||
| **kwargs: Any, | |||
| ) -> LLMResult: | |||
| return await super().agenerate(prompts, stop) | |||
| return await super().agenerate(prompts, stop, callbacks, **kwargs) | |||
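All four Streamable* wrappers now forward a per-call callbacks list (langchain's Callbacks type) instead of relying on a constructor-level CallbackManager. A minimal sketch of that calling convention using the stock ChatOpenAI class the chat wrappers extend; the API key and model name are placeholders:

from langchain.callbacks import StdOutCallbackHandler
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage

llm = ChatOpenAI(model_name="gpt-3.5-turbo", openai_api_key="<key>", streaming=False)

# Callbacks are passed per generate() call rather than stored on the model.
result = llm.generate(
    [[HumanMessage(content="Say hello.")]],
    callbacks=[StdOutCallbackHandler()],
)
print(result.generations[0][0].text)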
| @@ -1,7 +1,7 @@ | |||
| from typing import Any, List, Dict | |||
| from langchain.memory.chat_memory import BaseChatMemory | |||
| from langchain.schema import get_buffer_string, BaseMessage, BaseLanguageModel | |||
| from langchain.schema import get_buffer_string | |||
| from core.memory.read_only_conversation_token_db_buffer_shared_memory import \ | |||
| ReadOnlyConversationTokenDBBufferSharedMemory | |||
| @@ -1,5 +1,3 @@ | |||
| from llama_index import QueryKeywordExtractPrompt | |||
| CONVERSATION_TITLE_PROMPT = ( | |||
| "Human:{query}\n-----\n" | |||
| "Help me summarize the intent of what the human said and provide a title, the title should not exceed 20 words.\n" | |||
| @@ -45,23 +43,6 @@ SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = ( | |||
| "[\"question1\",\"question2\",\"question3\"]\n" | |||
| ) | |||
| QUERY_KEYWORD_EXTRACT_TEMPLATE_TMPL = ( | |||
| "A question is provided below. Given the question, extract up to {max_keywords} " | |||
| "keywords from the text. Focus on extracting the keywords that we can use " | |||
| "to best lookup answers to the question. Avoid stopwords." | |||
| "I am not sure which language the following question is in. " | |||
| "If the user asked the question in Chinese, please return the keywords in Chinese. " | |||
| "If the user asked the question in English, please return the keywords in English.\n" | |||
| "---------------------\n" | |||
| "{question}\n" | |||
| "---------------------\n" | |||
| "Provide keywords in the following comma-separated format: 'KEYWORDS: <keywords>'\n" | |||
| ) | |||
| QUERY_KEYWORD_EXTRACT_TEMPLATE = QueryKeywordExtractPrompt( | |||
| QUERY_KEYWORD_EXTRACT_TEMPLATE_TMPL | |||
| ) | |||
| RULE_CONFIG_GENERATE_TEMPLATE = """Given MY INTENDED AUDIENCES and HOPING TO SOLVE using a language model, please select \ | |||
| the model prompt that best suits the input. | |||
| You will be provided with the prompt, variables, and an opening statement. | |||
| @@ -0,0 +1,87 @@ | |||
| from flask import current_app | |||
| from langchain.embeddings import OpenAIEmbeddings | |||
| from langchain.tools import BaseTool | |||
| from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler | |||
| from core.embedding.cached_embedding import CacheEmbedding | |||
| from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig | |||
| from core.index.vector_index.vector_index import VectorIndex | |||
| from core.llm.llm_builder import LLMBuilder | |||
| from models.dataset import Dataset | |||
| class DatasetTool(BaseTool): | |||
| """Tool for querying a Dataset.""" | |||
| dataset: Dataset | |||
| k: int = 2 | |||
| def _run(self, tool_input: str) -> str: | |||
| if self.dataset.indexing_technique == "economy": | |||
| # use keyword table query | |||
| kw_table_index = KeywordTableIndex( | |||
| dataset=self.dataset, | |||
| config=KeywordTableConfig( | |||
| max_keywords_per_chunk=5 | |||
| ) | |||
| ) | |||
| documents = kw_table_index.search(tool_input, search_kwargs={'k': self.k}) | |||
| else: | |||
| model_credentials = LLMBuilder.get_model_credentials( | |||
| tenant_id=self.dataset.tenant_id, | |||
| model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id), | |||
| model_name='text-embedding-ada-002' | |||
| ) | |||
| embeddings = CacheEmbedding(OpenAIEmbeddings( | |||
| **model_credentials | |||
| )) | |||
| vector_index = VectorIndex( | |||
| dataset=self.dataset, | |||
| config=current_app.config, | |||
| embeddings=embeddings | |||
| ) | |||
| documents = vector_index.search( | |||
| tool_input, | |||
| search_type='similarity', | |||
| search_kwargs={ | |||
| 'k': self.k | |||
| } | |||
| ) | |||
| hit_callback = DatasetIndexToolCallbackHandler(self.dataset.id) | |||
| hit_callback.on_tool_end(documents) | |||
| return str("\n".join([document.page_content for document in documents])) | |||
| async def _arun(self, tool_input: str) -> str: | |||
| model_credentials = LLMBuilder.get_model_credentials( | |||
| tenant_id=self.dataset.tenant_id, | |||
| model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id), | |||
| model_name='text-embedding-ada-002' | |||
| ) | |||
| embeddings = CacheEmbedding(OpenAIEmbeddings( | |||
| **model_credentials | |||
| )) | |||
| vector_index = VectorIndex( | |||
| dataset=self.dataset, | |||
| config=current_app.config, | |||
| embeddings=embeddings | |||
| ) | |||
| documents = await vector_index.asearch( | |||
| tool_input, | |||
| search_type='similarity', | |||
| search_kwargs={ | |||
| 'k': 10 | |||
| } | |||
| ) | |||
| hit_callback = DatasetIndexToolCallbackHandler(self.dataset.id) | |||
| hit_callback.on_tool_end(documents) | |||
| return str("\n".join([document.page_content for document in documents])) | |||
| @@ -1,73 +0,0 @@ | |||
| from typing import Optional | |||
| from langchain.callbacks import CallbackManager | |||
| from llama_index.langchain_helpers.agents import IndexToolConfig | |||
| from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler | |||
| from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler | |||
| from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler | |||
| from core.index.keyword_table_index import KeywordTableIndex | |||
| from core.index.vector_index import VectorIndex | |||
| from core.prompt.prompts import QUERY_KEYWORD_EXTRACT_TEMPLATE | |||
| from core.tool.llama_index_tool import EnhanceLlamaIndexTool | |||
| from models.dataset import Dataset | |||
| class DatasetToolBuilder: | |||
| @classmethod | |||
| def build_dataset_tool(cls, dataset: Dataset, | |||
| response_mode: str = "no_synthesizer", | |||
| callback_handler: Optional[DatasetToolCallbackHandler] = None): | |||
| if dataset.indexing_technique == "economy": | |||
| # use keyword table query | |||
| index = KeywordTableIndex(dataset=dataset).query_index | |||
| if not index: | |||
| return None | |||
| query_kwargs = { | |||
| "mode": "default", | |||
| "response_mode": response_mode, | |||
| "query_keyword_extract_template": QUERY_KEYWORD_EXTRACT_TEMPLATE, | |||
| "max_keywords_per_query": 5, | |||
| # If num_chunks_per_query is too large, | |||
| # it will slow down the synthesis process due to multiple iterations of refinement. | |||
| "num_chunks_per_query": 2 | |||
| } | |||
| else: | |||
| index = VectorIndex(dataset=dataset).query_index | |||
| if not index: | |||
| return None | |||
| query_kwargs = { | |||
| "mode": "default", | |||
| "response_mode": response_mode, | |||
| # If top_k is too large, | |||
| # it will slow down the synthesis process due to multiple iterations of refinement. | |||
| "similarity_top_k": 2 | |||
| } | |||
| # fulfill description when it is empty | |||
| description = dataset.description | |||
| if not description: | |||
| description = 'useful for when you want to answer queries about the ' + dataset.name | |||
| index_tool_config = IndexToolConfig( | |||
| index=index, | |||
| name=f"dataset-{dataset.id}", | |||
| description=description, | |||
| index_query_kwargs=query_kwargs, | |||
| tool_kwargs={ | |||
| "callback_manager": CallbackManager([callback_handler, DifyStdOutCallbackHandler()]) | |||
| }, | |||
| # tool_kwargs={"return_direct": True}, | |||
| # return_direct: Whether to return LLM results directly or process the output data with an Output Parser | |||
| ) | |||
| index_callback_handler = DatasetIndexToolCallbackHandler(dataset_id=dataset.id) | |||
| return EnhanceLlamaIndexTool.from_tool_config( | |||
| tool_config=index_tool_config, | |||
| callback_handler=index_callback_handler | |||
| ) | |||
| @@ -1,43 +0,0 @@ | |||
| from typing import Dict | |||
| from langchain.tools import BaseTool | |||
| from llama_index.indices.base import BaseGPTIndex | |||
| from llama_index.langchain_helpers.agents import IndexToolConfig | |||
| from pydantic import Field | |||
| from core.callback_handler.index_tool_callback_handler import IndexToolCallbackHandler | |||
| class EnhanceLlamaIndexTool(BaseTool): | |||
| """Tool for querying a LlamaIndex.""" | |||
| # NOTE: name/description still needs to be set | |||
| index: BaseGPTIndex | |||
| query_kwargs: Dict = Field(default_factory=dict) | |||
| return_sources: bool = False | |||
| callback_handler: IndexToolCallbackHandler | |||
| @classmethod | |||
| def from_tool_config(cls, tool_config: IndexToolConfig, | |||
| callback_handler: IndexToolCallbackHandler) -> "EnhanceLlamaIndexTool": | |||
| """Create a tool from a tool config.""" | |||
| return_sources = tool_config.tool_kwargs.pop("return_sources", False) | |||
| return cls( | |||
| index=tool_config.index, | |||
| callback_handler=callback_handler, | |||
| name=tool_config.name, | |||
| description=tool_config.description, | |||
| return_sources=return_sources, | |||
| query_kwargs=tool_config.index_query_kwargs, | |||
| **tool_config.tool_kwargs, | |||
| ) | |||
| def _run(self, tool_input: str) -> str: | |||
| response = self.index.query(tool_input, **self.query_kwargs) | |||
| self.callback_handler.on_tool_end(response) | |||
| return str(response) | |||
| async def _arun(self, tool_input: str) -> str: | |||
| response = await self.index.aquery(tool_input, **self.query_kwargs) | |||
| self.callback_handler.on_tool_end(response) | |||
| return str(response) | |||
| @@ -1,34 +0,0 @@ | |||
| from abc import ABC, abstractmethod | |||
| from typing import Optional | |||
| from llama_index import ServiceContext, GPTVectorStoreIndex | |||
| from llama_index.data_structs import Node | |||
| from llama_index.vector_stores.types import VectorStore | |||
| class BaseVectorStoreClient(ABC): | |||
| @abstractmethod | |||
| def get_index(self, service_context: ServiceContext, config: dict) -> GPTVectorStoreIndex: | |||
| raise NotImplementedError | |||
| @abstractmethod | |||
| def to_index_config(self, index_id: str) -> dict: | |||
| raise NotImplementedError | |||
| class BaseGPTVectorStoreIndex(GPTVectorStoreIndex): | |||
| def delete_node(self, node_id: str): | |||
| self._vector_store.delete_node(node_id) | |||
| def exists_by_node_id(self, node_id: str) -> bool: | |||
| return self._vector_store.exists_by_node_id(node_id) | |||
| class EnhanceVectorStore(ABC): | |||
| @abstractmethod | |||
| def delete_node(self, node_id: str): | |||
| pass | |||
| @abstractmethod | |||
| def exists_by_node_id(self, node_id: str) -> bool: | |||
| pass | |||
| @@ -0,0 +1,69 @@ | |||
| from typing import cast, Any | |||
| from langchain.schema import Document | |||
| from langchain.vectorstores import Qdrant | |||
| from qdrant_client.http.models import Filter, PointIdsList, FilterSelector | |||
| from qdrant_client.local.qdrant_local import QdrantLocal | |||
| class QdrantVectorStore(Qdrant): | |||
| def del_texts(self, filter: Filter): | |||
| if not filter: | |||
| raise ValueError('filter must not be empty') | |||
| self._reload_if_needed() | |||
| self.client.delete( | |||
| collection_name=self.collection_name, | |||
| points_selector=FilterSelector( | |||
| filter=filter | |||
| ), | |||
| ) | |||
| def del_text(self, uuid: str) -> None: | |||
| self._reload_if_needed() | |||
| self.client.delete( | |||
| collection_name=self.collection_name, | |||
| points_selector=PointIdsList( | |||
| points=[uuid], | |||
| ), | |||
| ) | |||
| def text_exists(self, uuid: str) -> bool: | |||
| self._reload_if_needed() | |||
| response = self.client.retrieve( | |||
| collection_name=self.collection_name, | |||
| ids=[uuid] | |||
| ) | |||
| return len(response) > 0 | |||
| def delete(self): | |||
| self._reload_if_needed() | |||
| self.client.delete_collection(collection_name=self.collection_name) | |||
| @classmethod | |||
| def _document_from_scored_point( | |||
| cls, | |||
| scored_point: Any, | |||
| content_payload_key: str, | |||
| metadata_payload_key: str, | |||
| ) -> Document: | |||
| if scored_point.payload.get('doc_id'): | |||
| return Document( | |||
| page_content=scored_point.payload.get(content_payload_key), | |||
| metadata={'doc_id': scored_point.id} | |||
| ) | |||
| return Document( | |||
| page_content=scored_point.payload.get(content_payload_key), | |||
| metadata=scored_point.payload.get(metadata_payload_key) or {}, | |||
| ) | |||
| def _reload_if_needed(self): | |||
| if isinstance(self.client, QdrantLocal): | |||
| self.client = cast(QdrantLocal, self.client) | |||
| self.client._load() | |||
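A hedged example of the new deletion helper. The payload key path is an assumption based on how langchain's Qdrant wrapper stores metadata (under a metadata payload field) and on the document_id key the indexing runner writes into each chunk's metadata:

from qdrant_client.http.models import FieldCondition, Filter, MatchValue

# `vector_store` is an already-constructed QdrantVectorStore; delete every point
# belonging to one dataset document (payload key path assumed, see note above).
vector_store.del_texts(Filter(
    must=[FieldCondition(key="metadata.document_id",
                         match=MatchValue(value="<dataset-document-id>"))]
))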
| @@ -1,147 +0,0 @@ | |||
| import os | |||
| from typing import cast, List | |||
| from llama_index.data_structs import Node | |||
| from llama_index.data_structs.node_v2 import DocumentRelationship | |||
| from llama_index.vector_stores.types import VectorStoreQuery, VectorStoreQueryResult | |||
| from qdrant_client.http.models import Payload, Filter | |||
| import qdrant_client | |||
| from llama_index import ServiceContext, GPTVectorStoreIndex, GPTQdrantIndex | |||
| from llama_index.data_structs.data_structs_v2 import QdrantIndexDict | |||
| from llama_index.vector_stores import QdrantVectorStore | |||
| from qdrant_client.local.qdrant_local import QdrantLocal | |||
| from core.vector_store.base import BaseVectorStoreClient, BaseGPTVectorStoreIndex, EnhanceVectorStore | |||
| class QdrantVectorStoreClient(BaseVectorStoreClient): | |||
| def __init__(self, url: str, api_key: str, root_path: str): | |||
| self._client = self.init_from_config(url, api_key, root_path) | |||
| @classmethod | |||
| def init_from_config(cls, url: str, api_key: str, root_path: str): | |||
| if url and url.startswith('path:'): | |||
| path = url.replace('path:', '') | |||
| if not os.path.isabs(path): | |||
| path = os.path.join(root_path, path) | |||
| return qdrant_client.QdrantClient( | |||
| path=path | |||
| ) | |||
| else: | |||
| return qdrant_client.QdrantClient( | |||
| url=url, | |||
| api_key=api_key, | |||
| ) | |||
| def get_index(self, service_context: ServiceContext, config: dict) -> GPTVectorStoreIndex: | |||
| index_struct = QdrantIndexDict() | |||
| if self._client is None: | |||
| raise Exception("Vector client is not initialized.") | |||
| # {"collection_name": "Gpt_index_xxx"} | |||
| collection_name = config.get('collection_name') | |||
| if not collection_name: | |||
| raise Exception("collection_name cannot be None.") | |||
| return GPTQdrantEnhanceIndex( | |||
| service_context=service_context, | |||
| index_struct=index_struct, | |||
| vector_store=QdrantEnhanceVectorStore( | |||
| client=self._client, | |||
| collection_name=collection_name | |||
| ) | |||
| ) | |||
| def to_index_config(self, index_id: str) -> dict: | |||
| return {"collection_name": index_id} | |||
| class GPTQdrantEnhanceIndex(GPTQdrantIndex, BaseGPTVectorStoreIndex): | |||
| pass | |||
| class QdrantEnhanceVectorStore(QdrantVectorStore, EnhanceVectorStore): | |||
| def delete_node(self, node_id: str): | |||
| """ | |||
| Delete node from the index. | |||
| :param node_id: node id | |||
| """ | |||
| from qdrant_client.http import models as rest | |||
| self._reload_if_needed() | |||
| self._client.delete( | |||
| collection_name=self._collection_name, | |||
| points_selector=rest.Filter( | |||
| must=[ | |||
| rest.FieldCondition( | |||
| key="id", match=rest.MatchValue(value=node_id) | |||
| ) | |||
| ] | |||
| ), | |||
| ) | |||
| def exists_by_node_id(self, node_id: str) -> bool: | |||
| """ | |||
| Get node from the index by node id. | |||
| :param node_id: node id | |||
| """ | |||
| self._reload_if_needed() | |||
| response = self._client.retrieve( | |||
| collection_name=self._collection_name, | |||
| ids=[node_id] | |||
| ) | |||
| return len(response) > 0 | |||
| def query( | |||
| self, | |||
| query: VectorStoreQuery, | |||
| ) -> VectorStoreQueryResult: | |||
| """Query index for top k most similar nodes. | |||
| Args: | |||
| query (VectorStoreQuery): query | |||
| """ | |||
| query_embedding = cast(List[float], query.query_embedding) | |||
| self._reload_if_needed() | |||
| response = self._client.search( | |||
| collection_name=self._collection_name, | |||
| query_vector=query_embedding, | |||
| limit=cast(int, query.similarity_top_k), | |||
| query_filter=cast(Filter, self._build_query_filter(query)), | |||
| with_vectors=True | |||
| ) | |||
| nodes = [] | |||
| similarities = [] | |||
| ids = [] | |||
| for point in response: | |||
| payload = cast(Payload, point.payload) | |||
| node = Node( | |||
| doc_id=str(point.id), | |||
| text=payload.get("text"), | |||
| embedding=point.vector, | |||
| extra_info=payload.get("extra_info"), | |||
| relationships={ | |||
| DocumentRelationship.SOURCE: payload.get("doc_id", "None"), | |||
| }, | |||
| ) | |||
| nodes.append(node) | |||
| similarities.append(point.score) | |||
| ids.append(str(point.id)) | |||
| return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids) | |||
| def _reload_if_needed(self): | |||
| if isinstance(self._client._client, QdrantLocal): | |||
| self._client._client._load() | |||
| @@ -1,62 +0,0 @@ | |||
| from flask import Flask | |||
| from llama_index import ServiceContext, GPTVectorStoreIndex | |||
| from requests import ReadTimeout | |||
| from tenacity import retry, retry_if_exception_type, stop_after_attempt | |||
| from core.vector_store.qdrant_vector_store_client import QdrantVectorStoreClient | |||
| from core.vector_store.weaviate_vector_store_client import WeaviateVectorStoreClient | |||
| SUPPORTED_VECTOR_STORES = ['weaviate', 'qdrant'] | |||
| class VectorStore: | |||
| def __init__(self): | |||
| self._vector_store = None | |||
| self._client = None | |||
| def init_app(self, app: Flask): | |||
| if not app.config['VECTOR_STORE']: | |||
| return | |||
| self._vector_store = app.config['VECTOR_STORE'] | |||
| if self._vector_store not in SUPPORTED_VECTOR_STORES: | |||
| raise ValueError(f"Vector store {self._vector_store} is not supported.") | |||
| if self._vector_store == 'weaviate': | |||
| self._client = WeaviateVectorStoreClient( | |||
| endpoint=app.config['WEAVIATE_ENDPOINT'], | |||
| api_key=app.config['WEAVIATE_API_KEY'], | |||
| grpc_enabled=app.config['WEAVIATE_GRPC_ENABLED'], | |||
| batch_size=app.config['WEAVIATE_BATCH_SIZE'] | |||
| ) | |||
| elif self._vector_store == 'qdrant': | |||
| self._client = QdrantVectorStoreClient( | |||
| url=app.config['QDRANT_URL'], | |||
| api_key=app.config['QDRANT_API_KEY'], | |||
| root_path=app.root_path | |||
| ) | |||
| app.extensions['vector_store'] = self | |||
| @retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3)) | |||
| def get_index(self, service_context: ServiceContext, index_struct: dict) -> GPTVectorStoreIndex: | |||
| vector_store_config: dict = index_struct.get('vector_store') | |||
| index = self.get_client().get_index( | |||
| service_context=service_context, | |||
| config=vector_store_config | |||
| ) | |||
| return index | |||
| def to_index_struct(self, index_id: str) -> dict: | |||
| return { | |||
| "type": self._vector_store, | |||
| "vector_store": self.get_client().to_index_config(index_id) | |||
| } | |||
| def get_client(self): | |||
| if not self._client: | |||
| raise Exception("Vector store client is not initialized.") | |||
| return self._client | |||
| @@ -1,66 +0,0 @@ | |||
| from llama_index.indices.query.base import IS | |||
| from typing import ( | |||
| Any, | |||
| Dict, | |||
| List, | |||
| Optional | |||
| ) | |||
| from llama_index.docstore import BaseDocumentStore | |||
| from llama_index.indices.postprocessor.node import ( | |||
| BaseNodePostprocessor, | |||
| ) | |||
| from llama_index.indices.vector_store import GPTVectorStoreIndexQuery | |||
| from llama_index.indices.response.response_builder import ResponseMode | |||
| from llama_index.indices.service_context import ServiceContext | |||
| from llama_index.optimization.optimizer import BaseTokenUsageOptimizer | |||
| from llama_index.prompts.prompts import ( | |||
| QuestionAnswerPrompt, | |||
| RefinePrompt, | |||
| SimpleInputPrompt, | |||
| ) | |||
| from core.index.query.synthesizer import EnhanceResponseSynthesizer | |||
| class EnhanceGPTVectorStoreIndexQuery(GPTVectorStoreIndexQuery): | |||
| @classmethod | |||
| def from_args( | |||
| cls, | |||
| index_struct: IS, | |||
| service_context: ServiceContext, | |||
| docstore: Optional[BaseDocumentStore] = None, | |||
| node_postprocessors: Optional[List[BaseNodePostprocessor]] = None, | |||
| verbose: bool = False, | |||
| # response synthesizer args | |||
| response_mode: ResponseMode = ResponseMode.DEFAULT, | |||
| text_qa_template: Optional[QuestionAnswerPrompt] = None, | |||
| refine_template: Optional[RefinePrompt] = None, | |||
| simple_template: Optional[SimpleInputPrompt] = None, | |||
| response_kwargs: Optional[Dict] = None, | |||
| use_async: bool = False, | |||
| streaming: bool = False, | |||
| optimizer: Optional[BaseTokenUsageOptimizer] = None, | |||
| # class-specific args | |||
| **kwargs: Any, | |||
| ) -> "BaseGPTIndexQuery": | |||
| response_synthesizer = EnhanceResponseSynthesizer.from_args( | |||
| service_context=service_context, | |||
| text_qa_template=text_qa_template, | |||
| refine_template=refine_template, | |||
| simple_template=simple_template, | |||
| response_mode=response_mode, | |||
| response_kwargs=response_kwargs, | |||
| use_async=use_async, | |||
| streaming=streaming, | |||
| optimizer=optimizer, | |||
| ) | |||
| return cls( | |||
| index_struct=index_struct, | |||
| service_context=service_context, | |||
| response_synthesizer=response_synthesizer, | |||
| docstore=docstore, | |||
| node_postprocessors=node_postprocessors, | |||
| verbose=verbose, | |||
| **kwargs, | |||
| ) | |||
| @@ -0,0 +1,38 @@ | |||
| from langchain.vectorstores import Weaviate | |||
| class WeaviateVectorStore(Weaviate): | |||
| def del_texts(self, where_filter: dict): | |||
| if not where_filter: | |||
| raise ValueError('where_filter must not be empty') | |||
| self._client.batch.delete_objects( | |||
| class_name=self._index_name, | |||
| where=where_filter, | |||
| output='minimal' | |||
| ) | |||
| def del_text(self, uuid: str) -> None: | |||
| self._client.data_object.delete( | |||
| uuid, | |||
| class_name=self._index_name | |||
| ) | |||
| def text_exists(self, uuid: str) -> bool: | |||
| result = self._client.query.get(self._index_name).with_additional(["id"]).with_where({ | |||
| "path": ["doc_id"], | |||
| "operator": "Equal", | |||
| "valueText": uuid, | |||
| }).with_limit(1).do() | |||
| if "errors" in result: | |||
| raise ValueError(f"Error during query: {result['errors']}") | |||
| entries = result["data"]["Get"][self._index_name] | |||
| if len(entries) == 0: | |||
| return False | |||
| return True | |||
| def delete(self): | |||
| self._client.schema.delete_class(self._index_name) | |||
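The Weaviate counterpart takes a raw where-filter dict. A hedged example mirroring the text_exists filter above; the document_id property name is an assumption that mirrors the metadata written during indexing:

# `vector_store` is an already-constructed WeaviateVectorStore.
vector_store.del_texts({
    "path": ["document_id"],
    "operator": "Equal",
    "valueText": "<dataset-document-id>",
})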
| @@ -1,270 +0,0 @@ | |||
| import json | |||
| import weaviate | |||
| from dataclasses import field | |||
| from typing import List, Any, Dict, Optional | |||
| from core.vector_store.base import BaseVectorStoreClient, BaseGPTVectorStoreIndex, EnhanceVectorStore | |||
| from llama_index import ServiceContext, GPTWeaviateIndex, GPTVectorStoreIndex | |||
| from llama_index.data_structs.data_structs_v2 import WeaviateIndexDict, Node | |||
| from llama_index.data_structs.node_v2 import DocumentRelationship | |||
| from llama_index.readers.weaviate.client import _class_name, NODE_SCHEMA, _logger | |||
| from llama_index.vector_stores import WeaviateVectorStore | |||
| from llama_index.vector_stores.types import VectorStoreQuery, VectorStoreQueryResult, VectorStoreQueryMode | |||
| from llama_index.readers.weaviate.utils import ( | |||
| parse_get_response, | |||
| validate_client, | |||
| ) | |||
| class WeaviateVectorStoreClient(BaseVectorStoreClient): | |||
| def __init__(self, endpoint: str, api_key: str, grpc_enabled: bool, batch_size: int): | |||
| self._client = self.init_from_config(endpoint, api_key, grpc_enabled, batch_size) | |||
| def init_from_config(self, endpoint: str, api_key: str, grpc_enabled: bool, batch_size: int): | |||
| auth_config = weaviate.auth.AuthApiKey(api_key=api_key) | |||
| weaviate.connect.connection.has_grpc = grpc_enabled | |||
| client = weaviate.Client( | |||
| url=endpoint, | |||
| auth_client_secret=auth_config, | |||
| timeout_config=(5, 60), | |||
| startup_period=None | |||
| ) | |||
| client.batch.configure( | |||
| # `batch_size` takes an `int` value to enable auto-batching | |||
| # (`None` is used for manual batching) | |||
| batch_size=batch_size, | |||
| # dynamically update the `batch_size` based on import speed | |||
| dynamic=True, | |||
| # `timeout_retries` takes an `int` value to retry on time outs | |||
| timeout_retries=3, | |||
| ) | |||
| return client | |||
| def get_index(self, service_context: ServiceContext, config: dict) -> GPTVectorStoreIndex: | |||
| index_struct = WeaviateIndexDict() | |||
| if self._client is None: | |||
| raise Exception("Vector client is not initialized.") | |||
| # {"class_prefix": "Gpt_index_xxx"} | |||
| class_prefix = config.get('class_prefix') | |||
| if not class_prefix: | |||
| raise Exception("class_prefix cannot be None.") | |||
| return GPTWeaviateEnhanceIndex( | |||
| service_context=service_context, | |||
| index_struct=index_struct, | |||
| vector_store=WeaviateWithSimilaritiesVectorStore( | |||
| weaviate_client=self._client, | |||
| class_prefix=class_prefix | |||
| ) | |||
| ) | |||
| def to_index_config(self, index_id: str) -> dict: | |||
| return {"class_prefix": index_id} | |||
| class WeaviateWithSimilaritiesVectorStore(WeaviateVectorStore, EnhanceVectorStore): | |||
| def query(self, query: VectorStoreQuery) -> VectorStoreQueryResult: | |||
| """Query index for top k most similar nodes.""" | |||
| nodes = self.weaviate_query( | |||
| self._client, | |||
| self._class_prefix, | |||
| query, | |||
| ) | |||
| nodes = nodes[: query.similarity_top_k] | |||
| node_idxs = [str(i) for i in range(len(nodes))] | |||
| similarities = [] | |||
| for node in nodes: | |||
| similarities.append(node.extra_info['similarity']) | |||
| del node.extra_info['similarity'] | |||
| return VectorStoreQueryResult(nodes=nodes, ids=node_idxs, similarities=similarities) | |||
| def weaviate_query( | |||
| self, | |||
| client: Any, | |||
| class_prefix: str, | |||
| query_spec: VectorStoreQuery, | |||
| ) -> List[Node]: | |||
| """Convert to LlamaIndex list.""" | |||
| validate_client(client) | |||
| class_name = _class_name(class_prefix) | |||
| prop_names = [p["name"] for p in NODE_SCHEMA] | |||
| vector = query_spec.query_embedding | |||
| # build query | |||
| query = client.query.get(class_name, prop_names).with_additional(["id", "vector", "certainty"]) | |||
| if query_spec.mode == VectorStoreQueryMode.DEFAULT: | |||
| _logger.debug("Using vector search") | |||
| if vector is not None: | |||
| query = query.with_near_vector( | |||
| { | |||
| "vector": vector, | |||
| } | |||
| ) | |||
| elif query_spec.mode == VectorStoreQueryMode.HYBRID: | |||
| _logger.debug(f"Using hybrid search with alpha {query_spec.alpha}") | |||
| query = query.with_hybrid( | |||
| query=query_spec.query_str, | |||
| alpha=query_spec.alpha, | |||
| vector=vector, | |||
| ) | |||
| query = query.with_limit(query_spec.similarity_top_k) | |||
| _logger.debug(f"Using limit of {query_spec.similarity_top_k}") | |||
| # execute query | |||
| query_result = query.do() | |||
| # parse results | |||
| parsed_result = parse_get_response(query_result) | |||
| entries = parsed_result[class_name] | |||
| results = [self._to_node(entry) for entry in entries] | |||
| return results | |||
| def _to_node(self, entry: Dict) -> Node: | |||
| """Convert to Node.""" | |||
| extra_info_str = entry["extra_info"] | |||
| if extra_info_str == "": | |||
| extra_info = None | |||
| else: | |||
| extra_info = json.loads(extra_info_str) | |||
| if 'certainty' in entry['_additional']: | |||
| if extra_info: | |||
| extra_info['similarity'] = entry['_additional']['certainty'] | |||
| else: | |||
| extra_info = {'similarity': entry['_additional']['certainty']} | |||
| node_info_str = entry["node_info"] | |||
| if node_info_str == "": | |||
| node_info = None | |||
| else: | |||
| node_info = json.loads(node_info_str) | |||
| relationships_str = entry["relationships"] | |||
| relationships: Dict[DocumentRelationship, str] | |||
| if relationships_str == "": | |||
| relationships = {} | |||
| else: | |||
| relationships = { | |||
| DocumentRelationship(k): v for k, v in json.loads(relationships_str).items() | |||
| } | |||
| return Node( | |||
| text=entry["text"], | |||
| doc_id=entry["doc_id"], | |||
| embedding=entry["_additional"]["vector"], | |||
| extra_info=extra_info, | |||
| node_info=node_info, | |||
| relationships=relationships, | |||
| ) | |||
| def delete(self, doc_id: str, **delete_kwargs: Any) -> None: | |||
| """Delete a document. | |||
| Args: | |||
| doc_id (str): document id | |||
| """ | |||
| delete_document(self._client, doc_id, self._class_prefix) | |||
| def delete_node(self, node_id: str): | |||
| """ | |||
| Delete node from the index. | |||
| :param node_id: node id | |||
| """ | |||
| delete_node(self._client, node_id, self._class_prefix) | |||
| def exists_by_node_id(self, node_id: str) -> bool: | |||
| """ | |||
| Check whether a node exists in the index by node id. | |||
| :param node_id: node id | |||
| """ | |||
| entry = get_by_node_id(self._client, node_id, self._class_prefix) | |||
| return bool(entry) | |||
| class GPTWeaviateEnhanceIndex(GPTWeaviateIndex, BaseGPTVectorStoreIndex): | |||
| pass | |||
| def delete_document(client: Any, ref_doc_id: str, class_prefix: str) -> None: | |||
| """Delete entry.""" | |||
| validate_client(client) | |||
| # delete every entry whose ref_doc_id matches; re-query until no matches remain | |||
| class_name = _class_name(class_prefix) | |||
| where_filter = { | |||
| "path": ["ref_doc_id"], | |||
| "operator": "Equal", | |||
| "valueString": ref_doc_id, | |||
| } | |||
| query = ( | |||
| client.query.get(class_name).with_additional(["id"]).with_where(where_filter) | |||
| ) | |||
| query_result = query.do() | |||
| parsed_result = parse_get_response(query_result) | |||
| entries = parsed_result[class_name] | |||
| for entry in entries: | |||
| client.data_object.delete(entry["_additional"]["id"], class_name) | |||
| while len(entries) > 0: | |||
| query_result = query.do() | |||
| parsed_result = parse_get_response(query_result) | |||
| entries = parsed_result[class_name] | |||
| for entry in entries: | |||
| client.data_object.delete(entry["_additional"]["id"], class_name) | |||
| def delete_node(client: Any, node_id: str, class_prefix: str) -> None: | |||
| """Delete entry.""" | |||
| validate_client(client) | |||
| # match entries whose doc_id equals the given node id | |||
| class_name = _class_name(class_prefix) | |||
| where_filter = { | |||
| "path": ["doc_id"], | |||
| "operator": "Equal", | |||
| "valueString": node_id, | |||
| } | |||
| query = ( | |||
| client.query.get(class_name).with_additional(["id"]).with_where(where_filter) | |||
| ) | |||
| query_result = query.do() | |||
| parsed_result = parse_get_response(query_result) | |||
| entries = parsed_result[class_name] | |||
| for entry in entries: | |||
| client.data_object.delete(entry["_additional"]["id"], class_name) | |||
| def get_by_node_id(client: Any, node_id: str, class_prefix: str) -> Optional[Dict]: | |||
| """Delete entry.""" | |||
| validate_client(client) | |||
| # make sure that each entry | |||
| class_name = _class_name(class_prefix) | |||
| where_filter = { | |||
| "path": ["doc_id"], | |||
| "operator": "Equal", | |||
| "valueString": node_id, | |||
| } | |||
| query = ( | |||
| client.query.get(class_name).with_additional(["id"]).with_where(where_filter) | |||
| ) | |||
| query_result = query.do() | |||
| parsed_result = parse_get_response(query_result) | |||
| entries = parsed_result[class_name] | |||
| if len(entries) == 0: | |||
| return None | |||
| return entries[0] | |||
| @@ -1,7 +0,0 @@ | |||
| from core.vector_store.vector_store import VectorStore | |||
| vector_store = VectorStore() | |||
| def init_app(app): | |||
| vector_store.init_app(app) | |||
| @@ -3,6 +3,7 @@ import re | |||
| import subprocess | |||
| import uuid | |||
| from datetime import datetime | |||
| from hashlib import sha256 | |||
| from zoneinfo import available_timezones | |||
| import random | |||
| import string | |||
| @@ -147,3 +148,8 @@ def get_remote_ip(request): | |||
| return request.headers.getlist("X-Forwarded-For")[0] | |||
| else: | |||
| return request.remote_addr | |||
| def generate_text_hash(text: str) -> str: | |||
| hash_text = str(text) + 'None' | |||
| return sha256(hash_text.encode()).hexdigest() | |||
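A quick sanity check of the helper above (illustrative only): identical inputs always yield the same SHA-256 digest, so the hash can be used to deduplicate segment text before re-embedding it.

    h1 = generate_text_hash("hello world")
    h2 = generate_text_hash("hello world")
    assert h1 == h2                              # same text, same digest
    assert len(h1) == 64                         # hex-encoded sha256
    assert h1 != generate_text_hash("hello world!")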
| @@ -38,8 +38,6 @@ class Account(UserMixin, db.Model): | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| _current_tenant: db.Model = None | |||
| @property | |||
| def current_tenant(self): | |||
| return self._current_tenant | |||
| @@ -66,6 +66,23 @@ class Dataset(db.Model): | |||
| def document_count(self): | |||
| return db.session.query(func.count(Document.id)).filter(Document.dataset_id == self.id).scalar() | |||
| @property | |||
| def available_document_count(self): | |||
| return db.session.query(func.count(Document.id)).filter( | |||
| Document.dataset_id == self.id, | |||
| Document.indexing_status == 'completed', | |||
| Document.enabled == True, | |||
| Document.archived == False | |||
| ).scalar() | |||
| @property | |||
| def available_segment_count(self): | |||
| return db.session.query(func.count(DocumentSegment.id)).filter( | |||
| DocumentSegment.dataset_id == self.id, | |||
| DocumentSegment.status == 'completed', | |||
| DocumentSegment.enabled == True | |||
| ).scalar() | |||
| @property | |||
| def word_count(self): | |||
| return Document.query.with_entities(func.coalesce(func.sum(Document.word_count))) \ | |||
| @@ -260,7 +277,7 @@ class Document(db.Model): | |||
| @property | |||
| def dataset(self): | |||
| return Dataset.query.get(self.dataset_id) | |||
| return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).one_or_none() | |||
| @property | |||
| def segment_count(self): | |||
| @@ -395,7 +412,18 @@ class DatasetKeywordTable(db.Model): | |||
| @property | |||
| def keyword_table_dict(self): | |||
| return json.loads(self.keyword_table) if self.keyword_table else None | |||
| class SetDecoder(json.JSONDecoder): | |||
| def __init__(self, *args, **kwargs): | |||
| super().__init__(object_hook=self.object_hook, *args, **kwargs) | |||
| def object_hook(self, dct): | |||
| if isinstance(dct, dict): | |||
| for keyword, node_idxs in dct.items(): | |||
| if isinstance(node_idxs, list): | |||
| dct[keyword] = set(node_idxs) | |||
| return dct | |||
| return json.loads(self.keyword_table, cls=SetDecoder) if self.keyword_table else None | |||
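Illustratively, the custom decoder converts every JSON list value back into a set, so repeated node ids collapse when the keyword table is loaded (standalone sketch, assuming the same SetDecoder is available at module level):

    import json

    raw = '{"apple": ["node-1", "node-2", "node-1"], "banana": ["node-3"]}'
    table = json.loads(raw, cls=SetDecoder)

    # list values become sets, so duplicate node ids are de-duplicated
    assert table["apple"] == {"node-1", "node-2"}
    assert table["banana"] == {"node-3"}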
| class Embedding(db.Model): | |||
| @@ -2,6 +2,7 @@ coverage~=7.2.4 | |||
| beautifulsoup4==4.12.2 | |||
| flask~=2.3.2 | |||
| Flask-SQLAlchemy~=3.0.3 | |||
| SQLAlchemy~=1.4.28 | |||
| flask-login==0.6.2 | |||
| flask-migrate~=4.0.4 | |||
| flask-restful==0.3.9 | |||
| @@ -9,8 +10,7 @@ flask-session2==1.3.1 | |||
| flask-cors==3.0.10 | |||
| gunicorn~=20.1.0 | |||
| gevent~=22.10.2 | |||
| langchain==0.0.142 | |||
| llama-index==0.5.27 | |||
| langchain==0.0.209 | |||
| openai~=0.27.5 | |||
| psycopg2-binary~=2.9.6 | |||
| pycryptodome==3.17 | |||
| @@ -29,6 +29,7 @@ sentry-sdk[flask]~=1.21.1 | |||
| jieba==0.42.1 | |||
| celery==5.2.7 | |||
| redis~=4.5.4 | |||
| pypdf==3.8.1 | |||
| openpyxl==3.1.2 | |||
| chardet~=5.1.0 | |||
| docx2txt==0.8 | |||
| pypdfium2==4.16.0 | |||
| @@ -4,7 +4,6 @@ import uuid | |||
| from core.constant import llm_constant | |||
| from models.account import Account | |||
| from services.dataset_service import DatasetService | |||
| from services.errors.account import NoPermissionError | |||
| class AppModelConfigService: | |||
| @@ -7,7 +7,6 @@ from typing import Optional, List | |||
| from extensions.ext_redis import redis_client | |||
| from flask_login import current_user | |||
| from core.index.index_builder import IndexBuilder | |||
| from events.dataset_event import dataset_was_deleted | |||
| from events.document_event import document_was_deleted | |||
| from extensions.ext_database import db | |||
| @@ -386,8 +385,6 @@ class DocumentService: | |||
| dataset.indexing_technique = document_data["indexing_technique"] | |||
| if dataset.indexing_technique == 'high_quality': | |||
| IndexBuilder.get_default_service_context(dataset.tenant_id) | |||
| documents = [] | |||
| batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999)) | |||
| if 'original_document_id' in document_data and document_data["original_document_id"]: | |||
| @@ -3,47 +3,56 @@ import time | |||
| from typing import List | |||
| import numpy as np | |||
| from llama_index.data_structs.node_v2 import NodeWithScore | |||
| from llama_index.indices.query.schema import QueryBundle | |||
| from llama_index.indices.vector_store import GPTVectorStoreIndexQuery | |||
| from flask import current_app | |||
| from langchain.embeddings import OpenAIEmbeddings | |||
| from langchain.embeddings.base import Embeddings | |||
| from langchain.schema import Document | |||
| from sklearn.manifold import TSNE | |||
| from core.docstore.empty_docstore import EmptyDocumentStore | |||
| from core.index.vector_index import VectorIndex | |||
| from core.embedding.cached_embedding import CacheEmbedding | |||
| from core.index.vector_index.vector_index import VectorIndex | |||
| from core.llm.llm_builder import LLMBuilder | |||
| from extensions.ext_database import db | |||
| from models.account import Account | |||
| from models.dataset import Dataset, DocumentSegment, DatasetQuery | |||
| from services.errors.index import IndexNotInitializedError | |||
| class HitTestingService: | |||
| @classmethod | |||
| def retrieve(cls, dataset: Dataset, query: str, account: Account, limit: int = 10) -> dict: | |||
| index = VectorIndex(dataset=dataset).query_index | |||
| if not index: | |||
| raise IndexNotInitializedError() | |||
| index_query = GPTVectorStoreIndexQuery( | |||
| index_struct=index.index_struct, | |||
| service_context=index.service_context, | |||
| vector_store=index.query_context.get('vector_store'), | |||
| docstore=EmptyDocumentStore(), | |||
| response_synthesizer=None, | |||
| similarity_top_k=limit | |||
| ) | |||
| if dataset.available_document_count == 0 or dataset.available_segment_count == 0: | |||
| return { | |||
| "query": { | |||
| "content": query, | |||
| "tsne_position": {'x': 0, 'y': 0}, | |||
| }, | |||
| "records": [] | |||
| } | |||
| query_bundle = QueryBundle( | |||
| query_str=query, | |||
| custom_embedding_strs=[query], | |||
| model_credentials = LLMBuilder.get_model_credentials( | |||
| tenant_id=dataset.tenant_id, | |||
| model_provider=LLMBuilder.get_default_provider(dataset.tenant_id), | |||
| model_name='text-embedding-ada-002' | |||
| ) | |||
| query_bundle.embedding = index.service_context.embed_model.get_agg_embedding_from_queries( | |||
| query_bundle.embedding_strs | |||
| embeddings = CacheEmbedding(OpenAIEmbeddings( | |||
| **model_credentials | |||
| )) | |||
| vector_index = VectorIndex( | |||
| dataset=dataset, | |||
| config=current_app.config, | |||
| embeddings=embeddings | |||
| ) | |||
| start = time.perf_counter() | |||
| nodes = index_query.retrieve(query_bundle=query_bundle) | |||
| documents = vector_index.search( | |||
| query, | |||
| search_type='similarity_score_threshold', | |||
| search_kwargs={ | |||
| 'k': 10 | |||
| } | |||
| ) | |||
| end = time.perf_counter() | |||
| logging.debug(f"Hit testing retrieve in {end - start:0.4f} seconds") | |||
| @@ -58,25 +67,24 @@ class HitTestingService: | |||
| db.session.add(dataset_query) | |||
| db.session.commit() | |||
| return cls.compact_retrieve_response(dataset, query_bundle, nodes) | |||
| return cls.compact_retrieve_response(dataset, embeddings, query, documents) | |||
| @classmethod | |||
| def compact_retrieve_response(cls, dataset: Dataset, query_bundle: QueryBundle, nodes: List[NodeWithScore]): | |||
| embeddings = [ | |||
| query_bundle.embedding | |||
| def compact_retrieve_response(cls, dataset: Dataset, embeddings: Embeddings, query: str, documents: List[Document]): | |||
| text_embeddings = [ | |||
| embeddings.embed_query(query) | |||
| ] | |||
| for node in nodes: | |||
| embeddings.append(node.node.embedding) | |||
| text_embeddings.extend(embeddings.embed_documents([document.page_content for document in documents])) | |||
| tsne_position_data = cls.get_tsne_positions_from_embeddings(embeddings) | |||
| tsne_position_data = cls.get_tsne_positions_from_embeddings(text_embeddings) | |||
| query_position = tsne_position_data.pop(0) | |||
| i = 0 | |||
| records = [] | |||
| for node in nodes: | |||
| index_node_id = node.node.doc_id | |||
| for document in documents: | |||
| index_node_id = document.metadata['doc_id'] | |||
| segment = db.session.query(DocumentSegment).filter( | |||
| DocumentSegment.dataset_id == dataset.id, | |||
| @@ -91,7 +99,7 @@ class HitTestingService: | |||
| record = { | |||
| "segment": segment, | |||
| "score": node.score, | |||
| "score": document.metadata['score'], | |||
| "tsne_position": tsne_position_data[i] | |||
| } | |||
| @@ -101,7 +109,7 @@ class HitTestingService: | |||
| return { | |||
| "query": { | |||
| "content": query_bundle.query_str, | |||
| "content": query, | |||
| "tsne_position": query_position, | |||
| }, | |||
| "records": records | |||
| @@ -4,96 +4,81 @@ import time | |||
| import click | |||
| from celery import shared_task | |||
| from llama_index.data_structs import Node | |||
| from llama_index.data_structs.node_v2 import DocumentRelationship | |||
| from langchain.schema import Document | |||
| from werkzeug.exceptions import NotFound | |||
| from core.index.keyword_table_index import KeywordTableIndex | |||
| from core.index.vector_index import VectorIndex | |||
| from core.index.index import IndexBuilder | |||
| from extensions.ext_database import db | |||
| from extensions.ext_redis import redis_client | |||
| from models.dataset import DocumentSegment, Document | |||
| from models.dataset import DocumentSegment | |||
| from models.dataset import Document as DatasetDocument | |||
| @shared_task | |||
| def add_document_to_index_task(document_id: str): | |||
| def add_document_to_index_task(dataset_document_id: str): | |||
| """ | |||
| Async add document to index | |||
| :param dataset_document_id: | |||
| Usage: add_document_to_index_task.delay(dataset_document_id) | |||
| """ | |||
| logging.info(click.style('Start add document to index: {}'.format(document_id), fg='green')) | |||
| logging.info(click.style('Start add document to index: {}'.format(dataset_document_id), fg='green')) | |||
| start_at = time.perf_counter() | |||
| document = db.session.query(Document).filter(Document.id == document_id).first() | |||
| if not document: | |||
| dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document_id).first() | |||
| if not dataset_document: | |||
| raise NotFound('Document not found') | |||
| if document.indexing_status != 'completed': | |||
| if dataset_document.indexing_status != 'completed': | |||
| return | |||
| indexing_cache_key = 'document_{}_indexing'.format(document.id) | |||
| indexing_cache_key = 'document_{}_indexing'.format(dataset_document.id) | |||
| try: | |||
| segments = db.session.query(DocumentSegment).filter( | |||
| DocumentSegment.document_id == document.id, | |||
| DocumentSegment.document_id == dataset_document.id, | |||
| DocumentSegment.enabled == True | |||
| ) \ | |||
| .order_by(DocumentSegment.position.asc()).all() | |||
| nodes = [] | |||
| previous_node = None | |||
| documents = [] | |||
| for segment in segments: | |||
| relationships = { | |||
| DocumentRelationship.SOURCE: document.id | |||
| } | |||
| if previous_node: | |||
| relationships[DocumentRelationship.PREVIOUS] = previous_node.doc_id | |||
| previous_node.relationships[DocumentRelationship.NEXT] = segment.index_node_id | |||
| node = Node( | |||
| doc_id=segment.index_node_id, | |||
| doc_hash=segment.index_node_hash, | |||
| text=segment.content, | |||
| extra_info=None, | |||
| node_info=None, | |||
| relationships=relationships | |||
| document = Document( | |||
| page_content=segment.content, | |||
| metadata={ | |||
| "doc_id": segment.index_node_id, | |||
| "doc_hash": segment.index_node_hash, | |||
| "document_id": segment.document_id, | |||
| "dataset_id": segment.dataset_id, | |||
| } | |||
| ) | |||
| previous_node = node | |||
| documents.append(document) | |||
| nodes.append(node) | |||
| dataset = document.dataset | |||
| dataset = dataset_document.dataset | |||
| if not dataset: | |||
| raise Exception('Document has no dataset') | |||
| vector_index = VectorIndex(dataset=dataset) | |||
| keyword_table_index = KeywordTableIndex(dataset=dataset) | |||
| # save vector index | |||
| if dataset.indexing_technique == "high_quality": | |||
| vector_index.add_nodes( | |||
| nodes=nodes, | |||
| duplicate_check=True | |||
| ) | |||
| index = IndexBuilder.get_index(dataset, 'high_quality') | |||
| if index: | |||
| index.add_texts(documents) | |||
| # save keyword index | |||
| keyword_table_index.add_nodes(nodes) | |||
| index = IndexBuilder.get_index(dataset, 'economy') | |||
| if index: | |||
| index.add_texts(documents) | |||
| end_at = time.perf_counter() | |||
| logging.info( | |||
| click.style('Document added to index: {} latency: {}'.format(document.id, end_at - start_at), fg='green')) | |||
| click.style('Document added to index: {} latency: {}'.format(dataset_document.id, end_at - start_at), fg='green')) | |||
| except Exception as e: | |||
| logging.exception("add document to index failed") | |||
| document.enabled = False | |||
| document.disabled_at = datetime.datetime.utcnow() | |||
| document.status = 'error' | |||
| document.error = str(e) | |||
| dataset_document.enabled = False | |||
| dataset_document.disabled_at = datetime.datetime.utcnow() | |||
| dataset_document.status = 'error' | |||
| dataset_document.error = str(e) | |||
| db.session.commit() | |||
| finally: | |||
| redis_client.delete(indexing_cache_key) | |||
| @@ -4,12 +4,10 @@ import time | |||
| import click | |||
| from celery import shared_task | |||
| from llama_index.data_structs import Node | |||
| from llama_index.data_structs.node_v2 import DocumentRelationship | |||
| from langchain.schema import Document | |||
| from werkzeug.exceptions import NotFound | |||
| from core.index.keyword_table_index import KeywordTableIndex | |||
| from core.index.vector_index import VectorIndex | |||
| from core.index.index import IndexBuilder | |||
| from extensions.ext_database import db | |||
| from extensions.ext_redis import redis_client | |||
| from models.dataset import DocumentSegment | |||
| @@ -36,44 +34,41 @@ def add_segment_to_index_task(segment_id: str): | |||
| indexing_cache_key = 'segment_{}_indexing'.format(segment.id) | |||
| try: | |||
| relationships = { | |||
| DocumentRelationship.SOURCE: segment.document_id, | |||
| } | |||
| previous_segment = segment.previous_segment | |||
| if previous_segment: | |||
| relationships[DocumentRelationship.PREVIOUS] = previous_segment.index_node_id | |||
| next_segment = segment.next_segment | |||
| if next_segment: | |||
| relationships[DocumentRelationship.NEXT] = next_segment.index_node_id | |||
| node = Node( | |||
| doc_id=segment.index_node_id, | |||
| doc_hash=segment.index_node_hash, | |||
| text=segment.content, | |||
| extra_info=None, | |||
| node_info=None, | |||
| relationships=relationships | |||
| document = Document( | |||
| page_content=segment.content, | |||
| metadata={ | |||
| "doc_id": segment.index_node_id, | |||
| "doc_hash": segment.index_node_hash, | |||
| "document_id": segment.document_id, | |||
| "dataset_id": segment.dataset_id, | |||
| } | |||
| ) | |||
| dataset = segment.dataset | |||
| if not dataset: | |||
| raise Exception('Segment has no dataset') | |||
| logging.info(click.style('Segment {} has no dataset, pass.'.format(segment.id), fg='cyan')) | |||
| return | |||
| vector_index = VectorIndex(dataset=dataset) | |||
| keyword_table_index = KeywordTableIndex(dataset=dataset) | |||
| dataset_document = segment.document | |||
| if not dataset_document: | |||
| logging.info(click.style('Segment {} has no document, pass.'.format(segment.id), fg='cyan')) | |||
| return | |||
| if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed': | |||
| logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment.id), fg='cyan')) | |||
| return | |||
| # save vector index | |||
| if dataset.indexing_technique == "high_quality": | |||
| vector_index.add_nodes( | |||
| nodes=[node], | |||
| duplicate_check=True | |||
| ) | |||
| index = IndexBuilder.get_index(dataset, 'high_quality') | |||
| if index: | |||
| index.add_texts([document], duplicate_check=True) | |||
| # save keyword index | |||
| keyword_table_index.add_nodes([node]) | |||
| index = IndexBuilder.get_index(dataset, 'economy') | |||
| if index: | |||
| index.add_texts([document]) | |||
| end_at = time.perf_counter() | |||
| logging.info(click.style('Segment added to index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green')) | |||
| @@ -4,8 +4,7 @@ import time | |||
| import click | |||
| from celery import shared_task | |||
| from core.index.keyword_table_index import KeywordTableIndex | |||
| from core.index.vector_index import VectorIndex | |||
| from core.index.index import IndexBuilder | |||
| from extensions.ext_database import db | |||
| from models.dataset import DocumentSegment, Dataset, DatasetKeywordTable, DatasetQuery, DatasetProcessRule, \ | |||
| AppDatasetJoin | |||
| @@ -33,29 +32,24 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str, | |||
| index_struct=index_struct | |||
| ) | |||
| vector_index = VectorIndex(dataset=dataset) | |||
| keyword_table_index = KeywordTableIndex(dataset=dataset) | |||
| documents = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all() | |||
| index_doc_ids = [document.id for document in documents] | |||
| segments = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all() | |||
| index_node_ids = [segment.index_node_id for segment in segments] | |||
| # delete from vector index | |||
| if dataset.indexing_technique == "high_quality": | |||
| for index_doc_id in index_doc_ids: | |||
| try: | |||
| vector_index.del_doc(index_doc_id) | |||
| except Exception: | |||
| logging.exception("Delete doc index failed when dataset deleted.") | |||
| continue | |||
| vector_index = IndexBuilder.get_index(dataset, 'high_quality') | |||
| kw_index = IndexBuilder.get_index(dataset, 'economy') | |||
| # delete from keyword index | |||
| if index_node_ids: | |||
| # delete from vector index | |||
| if vector_index: | |||
| try: | |||
| keyword_table_index.del_nodes(index_node_ids) | |||
| vector_index.delete() | |||
| except Exception: | |||
| logging.exception("Delete nodes index failed when dataset deleted.") | |||
| logging.exception("Delete doc index failed when dataset deleted.") | |||
| # delete from keyword index | |||
| try: | |||
| kw_index.delete() | |||
| except Exception: | |||
| logging.exception("Delete nodes index failed when dataset deleted.") | |||
| for document in documents: | |||
| db.session.delete(document) | |||
| @@ -63,7 +57,6 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str, | |||
| for segment in segments: | |||
| db.session.delete(segment) | |||
| db.session.query(DatasetKeywordTable).filter(DatasetKeywordTable.dataset_id == dataset_id).delete() | |||
| db.session.query(DatasetProcessRule).filter(DatasetProcessRule.dataset_id == dataset_id).delete() | |||
| db.session.query(DatasetQuery).filter(DatasetQuery.dataset_id == dataset_id).delete() | |||
| db.session.query(AppDatasetJoin).filter(AppDatasetJoin.dataset_id == dataset_id).delete() | |||
| @@ -4,8 +4,7 @@ import time | |||
| import click | |||
| from celery import shared_task | |||
| from core.index.keyword_table_index import KeywordTableIndex | |||
| from core.index.vector_index import VectorIndex | |||
| from core.index.index import IndexBuilder | |||
| from extensions.ext_database import db | |||
| from models.dataset import DocumentSegment, Dataset | |||
| @@ -28,21 +27,23 @@ def clean_document_task(document_id: str, dataset_id: str): | |||
| if not dataset: | |||
| raise Exception('Document has no dataset') | |||
| vector_index = VectorIndex(dataset=dataset) | |||
| keyword_table_index = KeywordTableIndex(dataset=dataset) | |||
| vector_index = IndexBuilder.get_index(dataset, 'high_quality') | |||
| kw_index = IndexBuilder.get_index(dataset, 'economy') | |||
| segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() | |||
| index_node_ids = [segment.index_node_id for segment in segments] | |||
| # delete from vector index | |||
| vector_index.del_nodes(index_node_ids) | |||
| if vector_index: | |||
| vector_index.delete_by_document_id(document_id) | |||
| # delete from keyword index | |||
| if index_node_ids: | |||
| keyword_table_index.del_nodes(index_node_ids) | |||
| kw_index.delete_by_ids(index_node_ids) | |||
| for segment in segments: | |||
| db.session.delete(segment) | |||
| db.session.commit() | |||
| end_at = time.perf_counter() | |||
| logging.info( | |||
| @@ -5,8 +5,7 @@ from typing import List | |||
| import click | |||
| from celery import shared_task | |||
| from core.index.keyword_table_index import KeywordTableIndex | |||
| from core.index.vector_index import VectorIndex | |||
| from core.index.index import IndexBuilder | |||
| from extensions.ext_database import db | |||
| from models.dataset import DocumentSegment, Dataset, Document | |||
| @@ -29,22 +28,24 @@ def clean_notion_document_task(document_ids: List[str], dataset_id: str): | |||
| if not dataset: | |||
| raise Exception('Document has no dataset') | |||
| vector_index = VectorIndex(dataset=dataset) | |||
| keyword_table_index = KeywordTableIndex(dataset=dataset) | |||
| vector_index = IndexBuilder.get_index(dataset, 'high_quality') | |||
| kw_index = IndexBuilder.get_index(dataset, 'economy') | |||
| for document_id in document_ids: | |||
| document = db.session.query(Document).filter( | |||
| Document.id == document_id | |||
| ).first() | |||
| db.session.delete(document) | |||
| segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() | |||
| index_node_ids = [segment.index_node_id for segment in segments] | |||
| # delete from vector index | |||
| vector_index.del_nodes(index_node_ids) | |||
| if vector_index: | |||
| vector_index.delete_by_document_id(document_id) | |||
| # delete from keyword index | |||
| if index_node_ids: | |||
| keyword_table_index.del_nodes(index_node_ids) | |||
| kw_index.delete_by_ids(index_node_ids) | |||
| for segment in segments: | |||
| db.session.delete(segment) | |||
| @@ -3,10 +3,12 @@ import time | |||
| import click | |||
| from celery import shared_task | |||
| from llama_index.data_structs.node_v2 import DocumentRelationship, Node | |||
| from core.index.vector_index import VectorIndex | |||
| from langchain.schema import Document | |||
| from core.index.index import IndexBuilder | |||
| from extensions.ext_database import db | |||
| from models.dataset import DocumentSegment, Document, Dataset | |||
| from models.dataset import DocumentSegment, Dataset | |||
| from models.dataset import Document as DatasetDocument | |||
| @shared_task | |||
| @@ -24,49 +26,47 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): | |||
| dataset = Dataset.query.filter_by( | |||
| id=dataset_id | |||
| ).first() | |||
| if not dataset: | |||
| raise Exception('Dataset not found') | |||
| documents = Document.query.filter_by(dataset_id=dataset_id).all() | |||
| if documents: | |||
| vector_index = VectorIndex(dataset=dataset) | |||
| for document in documents: | |||
| # delete from vector index | |||
| if action == "remove": | |||
| vector_index.del_doc(document.id) | |||
| elif action == "add": | |||
| if action == "remove": | |||
| index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=True) | |||
| index.delete() | |||
| elif action == "add": | |||
| dataset_documents = db.session.query(DatasetDocument).filter( | |||
| DatasetDocument.dataset_id == dataset_id, | |||
| DatasetDocument.indexing_status == 'completed', | |||
| DatasetDocument.enabled == True, | |||
| DatasetDocument.archived == False, | |||
| ).all() | |||
| if dataset_documents: | |||
| # save vector index | |||
| index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=True) | |||
| for dataset_document in dataset_documents: | |||
| # delete from vector index | |||
| segments = db.session.query(DocumentSegment).filter( | |||
| DocumentSegment.document_id == document.id, | |||
| DocumentSegment.document_id == dataset_document.id, | |||
| DocumentSegment.enabled == True | |||
| ).order_by(DocumentSegment.position.asc()).all() | |||
| nodes = [] | |||
| previous_node = None | |||
| documents = [] | |||
| for segment in segments: | |||
| relationships = { | |||
| DocumentRelationship.SOURCE: document.id | |||
| } | |||
| if previous_node: | |||
| relationships[DocumentRelationship.PREVIOUS] = previous_node.doc_id | |||
| previous_node.relationships[DocumentRelationship.NEXT] = segment.index_node_id | |||
| node = Node( | |||
| doc_id=segment.index_node_id, | |||
| doc_hash=segment.index_node_hash, | |||
| text=segment.content, | |||
| extra_info=None, | |||
| node_info=None, | |||
| relationships=relationships | |||
| document = Document( | |||
| page_content=segment.content, | |||
| metadata={ | |||
| "doc_id": segment.index_node_id, | |||
| "doc_hash": segment.index_node_hash, | |||
| "document_id": segment.document_id, | |||
| "dataset_id": segment.dataset_id, | |||
| } | |||
| ) | |||
| previous_node = node | |||
| nodes.append(node) | |||
| documents.append(document) | |||
| # save vector index | |||
| vector_index.add_nodes( | |||
| nodes=nodes, | |||
| duplicate_check=True | |||
| ) | |||
| index.add_texts(documents) | |||
| end_at = time.perf_counter() | |||
| logging.info( | |||
| @@ -6,11 +6,9 @@ import click | |||
| from celery import shared_task | |||
| from werkzeug.exceptions import NotFound | |||
| from core.data_source.notion import NotionPageReader | |||
| from core.index.keyword_table_index import KeywordTableIndex | |||
| from core.index.vector_index import VectorIndex | |||
| from core.data_loader.loader.notion import NotionLoader | |||
| from core.index.index import IndexBuilder | |||
| from core.indexing_runner import IndexingRunner, DocumentIsPausedException | |||
| from core.llm.error import ProviderTokenNotInitError | |||
| from extensions.ext_database import db | |||
| from models.dataset import Document, Dataset, DocumentSegment | |||
| from models.source import DataSourceBinding | |||
| @@ -43,6 +41,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): | |||
| raise ValueError("no notion page found") | |||
| workspace_id = data_source_info['notion_workspace_id'] | |||
| page_id = data_source_info['notion_page_id'] | |||
| page_type = data_source_info['type'] | |||
| page_edited_time = data_source_info['last_edited_time'] | |||
| data_source_binding = DataSourceBinding.query.filter( | |||
| db.and_( | |||
| @@ -54,8 +53,16 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): | |||
| ).first() | |||
| if not data_source_binding: | |||
| raise ValueError('Data source binding not found.') | |||
| reader = NotionPageReader(integration_token=data_source_binding.access_token) | |||
| last_edited_time = reader.get_page_last_edited_time(page_id) | |||
| loader = NotionLoader( | |||
| notion_access_token=data_source_binding.access_token, | |||
| notion_workspace_id=workspace_id, | |||
| notion_obj_id=page_id, | |||
| notion_page_type=page_type | |||
| ) | |||
| last_edited_time = loader.get_notion_last_edited_time() | |||
| # check the page is updated | |||
| if last_edited_time != page_edited_time: | |||
| document.indexing_status = 'parsing' | |||
| @@ -68,18 +75,19 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): | |||
| if not dataset: | |||
| raise Exception('Dataset not found') | |||
| vector_index = VectorIndex(dataset=dataset) | |||
| keyword_table_index = KeywordTableIndex(dataset=dataset) | |||
| vector_index = IndexBuilder.get_index(dataset, 'high_quality') | |||
| kw_index = IndexBuilder.get_index(dataset, 'economy') | |||
| segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() | |||
| index_node_ids = [segment.index_node_id for segment in segments] | |||
| # delete from vector index | |||
| vector_index.del_nodes(index_node_ids) | |||
| if vector_index: | |||
| vector_index.delete_by_document_id(document_id) | |||
| # delete from keyword index | |||
| if index_node_ids: | |||
| keyword_table_index.del_nodes(index_node_ids) | |||
| kw_index.delete_by_ids(index_node_ids) | |||
| for segment in segments: | |||
| db.session.delete(segment) | |||
| @@ -89,21 +97,13 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): | |||
| click.style('Cleaned document when document update data source or process rule: {} latency: {}'.format(document_id, end_at - start_at), fg='green')) | |||
| except Exception: | |||
| logging.exception("Cleaned document when document update data source or process rule failed") | |||
| try: | |||
| indexing_runner = IndexingRunner() | |||
| indexing_runner.run([document]) | |||
| end_at = time.perf_counter() | |||
| logging.info(click.style('update document: {} latency: {}'.format(document.id, end_at - start_at), fg='green')) | |||
| except DocumentIsPausedException: | |||
| logging.info(click.style('Document update paused, document id: {}'.format(document.id), fg='yellow')) | |||
| except ProviderTokenNotInitError as e: | |||
| document.indexing_status = 'error' | |||
| document.error = str(e.description) | |||
| document.stopped_at = datetime.datetime.utcnow() | |||
| db.session.commit() | |||
| except Exception as e: | |||
| logging.exception("consume update document failed") | |||
| document.indexing_status = 'error' | |||
| document.error = str(e) | |||
| document.stopped_at = datetime.datetime.utcnow() | |||
| db.session.commit() | |||
| except DocumentIsPausedException as ex: | |||
| logging.info(click.style(str(ex), fg='yellow')) | |||
| except Exception: | |||
| pass | |||
| @@ -7,7 +7,6 @@ from celery import shared_task | |||
| from werkzeug.exceptions import NotFound | |||
| from core.indexing_runner import IndexingRunner, DocumentIsPausedException | |||
| from core.llm.error import ProviderTokenNotInitError | |||
| from extensions.ext_database import db | |||
| from models.dataset import Document | |||
| @@ -22,9 +21,9 @@ def document_indexing_task(dataset_id: str, document_ids: list): | |||
| Usage: document_indexing_task.delay(dataset_id, document_id) | |||
| """ | |||
| documents = [] | |||
| start_at = time.perf_counter() | |||
| for document_id in document_ids: | |||
| logging.info(click.style('Start process document: {}'.format(document_id), fg='green')) | |||
| start_at = time.perf_counter() | |||
| document = db.session.query(Document).filter( | |||
| Document.id == document_id, | |||
| @@ -44,17 +43,8 @@ def document_indexing_task(dataset_id: str, document_ids: list): | |||
| indexing_runner = IndexingRunner() | |||
| indexing_runner.run(documents) | |||
| end_at = time.perf_counter() | |||
| logging.info(click.style('Processed document: {} latency: {}'.format(document.id, end_at - start_at), fg='green')) | |||
| except DocumentIsPausedException: | |||
| logging.info(click.style('Document paused, document id: {}'.format(document.id), fg='yellow')) | |||
| except ProviderTokenNotInitError as e: | |||
| document.indexing_status = 'error' | |||
| document.error = str(e.description) | |||
| document.stopped_at = datetime.datetime.utcnow() | |||
| db.session.commit() | |||
| except Exception as e: | |||
| logging.exception("consume document failed") | |||
| document.indexing_status = 'error' | |||
| document.error = str(e) | |||
| document.stopped_at = datetime.datetime.utcnow() | |||
| db.session.commit() | |||
| logging.info(click.style('Processed dataset: {} latency: {}'.format(dataset_id, end_at - start_at), fg='green')) | |||
| except DocumentIsPausedException as ex: | |||
| logging.info(click.style(str(ex), fg='yellow')) | |||
| except Exception: | |||
| pass | |||
| @@ -6,10 +6,8 @@ import click | |||
| from celery import shared_task | |||
| from werkzeug.exceptions import NotFound | |||
| from core.index.keyword_table_index import KeywordTableIndex | |||
| from core.index.vector_index import VectorIndex | |||
| from core.index.index import IndexBuilder | |||
| from core.indexing_runner import IndexingRunner, DocumentIsPausedException | |||
| from core.llm.error import ProviderTokenNotInitError | |||
| from extensions.ext_database import db | |||
| from models.dataset import Document, Dataset, DocumentSegment | |||
| @@ -44,18 +42,19 @@ def document_indexing_update_task(dataset_id: str, document_id: str): | |||
| if not dataset: | |||
| raise Exception('Dataset not found') | |||
| vector_index = VectorIndex(dataset=dataset) | |||
| keyword_table_index = KeywordTableIndex(dataset=dataset) | |||
| vector_index = IndexBuilder.get_index(dataset, 'high_quality') | |||
| kw_index = IndexBuilder.get_index(dataset, 'economy') | |||
| segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() | |||
| index_node_ids = [segment.index_node_id for segment in segments] | |||
| # delete from vector index | |||
| vector_index.del_nodes(index_node_ids) | |||
| if vector_index: | |||
| vector_index.delete_by_ids(index_node_ids) | |||
| # delete from keyword index | |||
| if index_node_ids: | |||
| keyword_table_index.del_nodes(index_node_ids) | |||
| kw_index.delete_by_ids(index_node_ids) | |||
| for segment in segments: | |||
| db.session.delete(segment) | |||
| @@ -65,21 +64,13 @@ def document_indexing_update_task(dataset_id: str, document_id: str): | |||
| click.style('Cleaned document when document update data source or process rule: {} latency: {}'.format(document_id, end_at - start_at), fg='green')) | |||
| except Exception: | |||
| logging.exception("Cleaned document when document update data source or process rule failed") | |||
| try: | |||
| indexing_runner = IndexingRunner() | |||
| indexing_runner.run([document]) | |||
| end_at = time.perf_counter() | |||
| logging.info(click.style('update document: {} latency: {}'.format(document.id, end_at - start_at), fg='green')) | |||
| except DocumentIsPausedException: | |||
| logging.info(click.style('Document update paused, document id: {}'.format(document.id), fg='yellow')) | |||
| except ProviderTokenNotInitError as e: | |||
| document.indexing_status = 'error' | |||
| document.error = str(e.description) | |||
| document.stopped_at = datetime.datetime.utcnow() | |||
| db.session.commit() | |||
| except Exception as e: | |||
| logging.exception("consume update document failed") | |||
| document.indexing_status = 'error' | |||
| document.error = str(e) | |||
| document.stopped_at = datetime.datetime.utcnow() | |||
| db.session.commit() | |||
| except DocumentIsPausedException as ex: | |||
| logging.info(click.style(str(ex), fg='yellow')) | |||
| except Exception: | |||
| pass | |||
| @@ -1,4 +1,3 @@ | |||
| import datetime | |||
| import logging | |||
| import time | |||
| @@ -41,11 +40,7 @@ def recover_document_indexing_task(dataset_id: str, document_id: str): | |||
| indexing_runner.run_in_indexing_status(document) | |||
| end_at = time.perf_counter() | |||
| logging.info(click.style('Processed document: {} latency: {}'.format(document.id, end_at - start_at), fg='green')) | |||
| except DocumentIsPausedException: | |||
| logging.info(click.style('Document paused, document id: {}'.format(document.id), fg='yellow')) | |||
| except Exception as e: | |||
| logging.exception("consume document failed") | |||
| document.indexing_status = 'error' | |||
| document.error = str(e) | |||
| document.stopped_at = datetime.datetime.utcnow() | |||
| db.session.commit() | |||
| except DocumentIsPausedException as ex: | |||
| logging.info(click.style(str(ex), fg='yellow')) | |||
| except Exception: | |||
| pass | |||
| @@ -5,8 +5,7 @@ import click | |||
| from celery import shared_task | |||
| from werkzeug.exceptions import NotFound | |||
| from core.index.keyword_table_index import KeywordTableIndex | |||
| from core.index.vector_index import VectorIndex | |||
| from core.index.index import IndexBuilder | |||
| from extensions.ext_database import db | |||
| from extensions.ext_redis import redis_client | |||
| from models.dataset import DocumentSegment, Document | |||
| @@ -38,17 +37,17 @@ def remove_document_from_index_task(document_id: str): | |||
| if not dataset: | |||
| raise Exception('Document has no dataset') | |||
| vector_index = VectorIndex(dataset=dataset) | |||
| keyword_table_index = KeywordTableIndex(dataset=dataset) | |||
| vector_index = IndexBuilder.get_index(dataset, 'high_quality') | |||
| kw_index = IndexBuilder.get_index(dataset, 'economy') | |||
| # delete from vector index | |||
| vector_index.del_doc(document.id) | |||
| vector_index.delete_by_document_id(document.id) | |||
| # delete from keyword index | |||
| segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).all() | |||
| index_node_ids = [segment.index_node_id for segment in segments] | |||
| if index_node_ids: | |||
| keyword_table_index.del_nodes(index_node_ids) | |||
| kw_index.delete_by_ids(index_node_ids) | |||
| end_at = time.perf_counter() | |||
| logging.info( | |||
| @@ -5,8 +5,7 @@ import click | |||
| from celery import shared_task | |||
| from werkzeug.exceptions import NotFound | |||
| from core.index.keyword_table_index import KeywordTableIndex | |||
| from core.index.vector_index import VectorIndex | |||
| from core.index.index import IndexBuilder | |||
| from extensions.ext_database import db | |||
| from extensions.ext_redis import redis_client | |||
| from models.dataset import DocumentSegment | |||
| @@ -36,17 +35,28 @@ def remove_segment_from_index_task(segment_id: str): | |||
| dataset = segment.dataset | |||
| if not dataset: | |||
| raise Exception('Segment has no dataset') | |||
| logging.info(click.style('Segment {} has no dataset, pass.'.format(segment.id), fg='cyan')) | |||
| return | |||
| vector_index = VectorIndex(dataset=dataset) | |||
| keyword_table_index = KeywordTableIndex(dataset=dataset) | |||
| dataset_document = segment.document | |||
| if not dataset_document: | |||
| logging.info(click.style('Segment {} has no document, pass.'.format(segment.id), fg='cyan')) | |||
| return | |||
| if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed': | |||
| logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment.id), fg='cyan')) | |||
| return | |||
| vector_index = IndexBuilder.get_index(dataset, 'high_quality') | |||
| kw_index = IndexBuilder.get_index(dataset, 'economy') | |||
| # delete from vector index | |||
| if dataset.indexing_technique == "high_quality": | |||
| vector_index.del_nodes([segment.index_node_id]) | |||
| if vector_index: | |||
| vector_index.delete_by_ids([segment.index_node_id]) | |||
| # delete from keyword index | |||
| keyword_table_index.del_nodes([segment.index_node_id]) | |||
| kw_index.delete_by_ids([segment.index_node_id]) | |||
| end_at = time.perf_counter() | |||
| logging.info(click.style('Segment removed from index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green')) | |||