Co-authored-by: jyong <718720800@qq.com>

tags/0.3.6
 import flask_login
 from flask_cors import CORS
-from extensions import ext_session, ext_celery, ext_sentry, ext_redis, ext_login, ext_vector_store, ext_migrate, \
+from extensions import ext_session, ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \
     ext_database, ext_storage
 from extensions.ext_database import db
 from extensions.ext_login import login_manager

 ext_database.init_app(app)
 ext_migrate.init(app, db)
 ext_redis.init_app(app)
-ext_vector_store.init_app(app)
 ext_storage.init_app(app)
 ext_celery.init_app(app)
 ext_session.init_app(app)

 import datetime
+import logging
 import random
 import string

 import click
 from flask import current_app
+from werkzeug.exceptions import NotFound

+from core.index.index import IndexBuilder
 from libs.password import password_pattern, valid_password, hash_password
 from libs.helper import email as email_validate
 from extensions.ext_database import db
 from libs.rsa import generate_key_pair
 from models.account import InvitationCode, Tenant
+from models.dataset import Dataset
 from models.model import Account
 import secrets
 import base64

     return result


+@click.command('recreate-all-dataset-indexes', help='Recreate all dataset indexes.')
+def recreate_all_dataset_indexes():
+    click.echo(click.style('Start recreate all dataset indexes.', fg='green'))
+    recreate_count = 0
+
+    page = 1
+    while True:
+        try:
+            datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
+                .order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
+        except NotFound:
+            break
+
+        page += 1
+        for dataset in datasets:
+            try:
+                click.echo('Recreating dataset index: {}'.format(dataset.id))
+                index = IndexBuilder.get_index(dataset, 'high_quality')
+                if index and index._is_origin():
+                    index.recreate_dataset(dataset)
+                    recreate_count += 1
+                else:
+                    click.echo('passed.')
+            except Exception as e:
+                click.echo(click.style('Recreate dataset index error: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
+                continue
+
+    click.echo(click.style('Congratulations! Recreate {} dataset indexes.'.format(recreate_count), fg='green'))


 def register_commands(app):
     app.cli.add_command(reset_password)
     app.cli.add_command(reset_email)
     app.cli.add_command(generate_invitation_codes)
     app.cli.add_command(reset_encrypt_key_pair)
+    app.cli.add_command(recreate_all_dataset_indexes)

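
Note on the loop in the new command: it relies on Flask-SQLAlchemy's paginate() aborting with werkzeug's NotFound once the requested page is out of range, which is what terminates the while True. A minimal sketch of that pattern, assuming Query.paginate with its default error_out=True (the command iterates the pagination result directly; iterating pagination.items works on older Flask-SQLAlchemy versions as well):

from werkzeug.exceptions import NotFound

def iterate_in_pages(query, per_page=50):
    """Yield rows page by page until paginate() aborts with NotFound."""
    page = 1
    while True:
        try:
            pagination = query.paginate(page=page, per_page=per_page)  # error_out=True by default
        except NotFound:
            break

        page += 1
        for item in pagination.items:
            yield item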
         # For temp use only
         # set default LLM provider, default is 'openai', support `azure_openai`
         self.DEFAULT_LLM_PROVIDER = get_env('DEFAULT_LLM_PROVIDER')

         # notion import setting
         self.NOTION_CLIENT_ID = get_env('NOTION_CLIENT_ID')
         self.NOTION_CLIENT_SECRET = get_env('NOTION_CLIENT_SECRET')
         self.NOTION_INTEGRATION_TYPE = get_env('NOTION_INTEGRATION_TYPE')
         self.NOTION_INTERNAL_SECRET = get_env('NOTION_INTERNAL_SECRET')
+        self.NOTION_INTEGRATION_TOKEN = get_env('NOTION_INTEGRATION_TOKEN')


 class CloudEditionConfig(Config):

 from controllers.console import api
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.data_source.notion import NotionPageReader
+from core.data_loader.loader.notion import NotionLoader
 from core.indexing_runner import IndexingRunner
 from extensions.ext_database import db
 from libs.helper import TimestampField
-from libs.oauth_data_source import NotionOAuth
 from models.dataset import Document
 from models.source import DataSourceBinding
 from services.dataset_service import DatasetService, DocumentService

         ).first()

         if not data_source_binding:
             raise NotFound('Data source binding not found.')

-        reader = NotionPageReader(integration_token=data_source_binding.access_token)
-
-        if page_type == 'page':
-            page_content = reader.read_page(page_id)
-        elif page_type == 'database':
-            page_content = reader.query_database_data(page_id)
-        else:
-            page_content = ""
+        loader = NotionLoader(
+            notion_access_token=data_source_binding.access_token,
+            notion_workspace_id=workspace_id,
+            notion_obj_id=page_id,
+            notion_page_type=page_type
+        )
+
+        text_docs = loader.load()

         return {
-            'content': page_content
+            'content': "\n".join([doc.page_content for doc in text_docs])
         }, 200

     @setup_required

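
For reference, a hypothetical standalone use of the loader this endpoint now delegates to, using only the constructor arguments visible in the change above; the token and ids are placeholders, and load() returning langchain Documents is taken from how the result is joined here:

from core.data_loader.loader.notion import NotionLoader

loader = NotionLoader(
    notion_access_token="<integration-or-oauth-token>",  # placeholder
    notion_workspace_id="<workspace-id>",
    notion_obj_id="<page-or-database-id>",
    notion_page_type="page",  # or "database"
)
text_docs = loader.load()
preview = "\n".join(doc.page_content for doc in text_docs)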
     UnsupportedFileTypeError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.index.readers.html_parser import HTMLParser
-from core.index.readers.pdf_parser import PDFParser
-from core.index.readers.xlsx_parser import XLSXParser
+from core.data_loader.file_extractor import FileExtractor
 from extensions.ext_storage import storage
 from libs.helper import TimestampField
 from extensions.ext_database import db

         if extension not in ALLOWED_EXTENSIONS:
             raise UnsupportedFileTypeError()

-        with tempfile.TemporaryDirectory() as temp_dir:
-            suffix = Path(upload_file.key).suffix
-            filepath = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
-            storage.download(upload_file.key, filepath)
-
-            if extension == 'pdf':
-                parser = PDFParser({'upload_file': upload_file})
-                text = parser.parse_file(Path(filepath))
-            elif extension in ['html', 'htm']:
-                # Use BeautifulSoup to extract text
-                parser = HTMLParser()
-                text = parser.parse_file(Path(filepath))
-            elif extension == 'xlsx':
-                parser = XLSXParser()
-                text = parser.parse_file(filepath)
-            else:
-                # ['txt', 'markdown', 'md']
-                with open(filepath, "rb") as fp:
-                    data = fp.read()
-                    encoding = chardet.detect(data)['encoding']
-                    if encoding:
-                        text = data.decode(encoding=encoding).strip() if data else ''
-                    else:
-                        text = data.decode(encoding='utf-8').strip() if data else ''
+        text = FileExtractor.load(upload_file, return_text=True)

         text = text[0:PREVIEW_WORDS_LIMIT] if text else ''

         return {'content': text}

                 'current_version': args.get('current_version')
             })
         except Exception as error:
-            logging.exception("Check update error.")
-            raise InternalServerError()
+            logging.warning("Check update version error: {}.".format(str(error)))
+            return {
+                'version': args.get('current_version'),
+                'release_date': '',
+                'release_notes': '',
+                'can_auto_update': False
+            }

         content = json.loads(response.content)

         return {

 import langchain
 from flask import Flask
-from jieba.analyse import default_tfidf
-from langchain import set_handler
 from langchain.prompts.base import DEFAULT_FORMATTER_MAPPING
-from llama_index import IndexStructType, QueryMode
-from llama_index.indices.registry import INDEX_STRUT_TYPE_TO_QUERY_MAP
 from pydantic import BaseModel

 from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
-from core.index.keyword_table.jieba_keyword_table import GPTJIEBAKeywordTableIndex
-from core.index.keyword_table.stopwords import STOPWORDS
 from core.prompt.prompt_template import OneLineFormatter
-from core.vector_store.vector_store import VectorStore
-from core.vector_store.vector_store_index_query import EnhanceGPTVectorStoreIndexQuery


 class HostedOpenAICredential(BaseModel):


 def init_app(app: Flask):
     formatter = OneLineFormatter()
     DEFAULT_FORMATTER_MAPPING['f-string'] = formatter.format

-    INDEX_STRUT_TYPE_TO_QUERY_MAP[IndexStructType.KEYWORD_TABLE] = GPTJIEBAKeywordTableIndex.get_query_map()
-    INDEX_STRUT_TYPE_TO_QUERY_MAP[IndexStructType.WEAVIATE] = {
-        QueryMode.DEFAULT: EnhanceGPTVectorStoreIndexQuery,
-        QueryMode.EMBEDDING: EnhanceGPTVectorStoreIndexQuery,
-    }
-    INDEX_STRUT_TYPE_TO_QUERY_MAP[IndexStructType.QDRANT] = {
-        QueryMode.DEFAULT: EnhanceGPTVectorStoreIndexQuery,
-        QueryMode.EMBEDDING: EnhanceGPTVectorStoreIndexQuery,
-    }
-
-    default_tfidf.stop_words = STOPWORDS

     if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
         langchain.verbose = True
-        set_handler(DifyStdOutCallbackHandler())

     if app.config.get("OPENAI_API_KEY"):
         hosted_llm_credentials.openai = HostedOpenAICredential(api_key=app.config.get("OPENAI_API_KEY"))

 from langchain import LLMChain
 from langchain.agents import ZeroShotAgent, AgentExecutor, ConversationalAgent
-from langchain.callbacks import CallbackManager
+from langchain.callbacks.manager import CallbackManager
 from langchain.memory.chat_memory import BaseChatMemory

 from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler

     def to_agent_chain(cls, tenant_id: str, tools, memory: Optional[BaseChatMemory],
                        dataset_tool_callback_handler: DatasetToolCallbackHandler,
                        agent_loop_gather_callback_handler: AgentLoopGatherCallbackHandler):
-        llm_callback_manager = CallbackManager([agent_loop_gather_callback_handler, DifyStdOutCallbackHandler()])
         llm = LLMBuilder.to_llm(
             tenant_id=tenant_id,
             model_name=agent_loop_gather_callback_handler.model_name,
             temperature=0,
             max_tokens=1024,
-            callback_manager=llm_callback_manager
+            callbacks=[agent_loop_gather_callback_handler, DifyStdOutCallbackHandler()]
         )

-        tool_callback_manager = CallbackManager([
-            agent_loop_gather_callback_handler,
-            dataset_tool_callback_handler,
-            DifyStdOutCallbackHandler()
-        ])
-
         for tool in tools:
-            tool.callback_manager = tool_callback_manager
+            tool.callbacks = [
+                agent_loop_gather_callback_handler,
+                dataset_tool_callback_handler,
+                DifyStdOutCallbackHandler()
+            ]

         prompt = cls.build_agent_prompt_template(
             tools=tools,

             tools=tools,
             agent=agent,
             memory=memory,
-            callback_manager=agent_callback_manager,
+            callbacks=agent_callback_manager,
             max_iterations=6,
             early_stopping_method="generate",
             # `generate` will continue to complete the last inference after reaching the iteration limit or request time limit

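
The recurring change in this file (and in the chain builders below) is the callback wiring: the upgraded langchain accepts a plain list of handlers via callbacks= on LLMs, tools and executors instead of a CallbackManager wrapper passed as callback_manager=. A minimal sketch of the new style, with a throwaway handler standing in for Dify's own (the handler name is illustrative):

from typing import Any, Dict, List

from langchain.callbacks.base import BaseCallbackHandler


class LoggingHandler(BaseCallbackHandler):
    """Illustrative handler; Dify's real handlers gather agent loops, chains, etc."""

    def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
        print("LLM starting with prompt:", prompts[0][:80])


handlers = [LoggingHandler()]

# old style: llm = SomeLLM(..., callback_manager=CallbackManager(handlers))
# new style: llm = SomeLLM(..., callbacks=handlers)
# tools:     tool.callbacks = handlers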
 class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
     """Callback Handler that prints to std out."""
+    raise_error: bool = True

     def __init__(self, model_name, conversation_message_task: ConversationMessageTask) -> None:
         """Initialize callback handler."""

             self._current_loop.completion = response.generations[0][0].text
             self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens']

-    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
-        """Do nothing."""
-        pass
-
     def on_llm_error(
         self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
     ) -> None:

         self._agent_loops = []
         self._current_loop = None

-    def on_chain_start(
-        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
-    ) -> None:
-        """Print out that we are entering a chain."""
-        pass
-
-    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
-        """Print out that we finished a chain."""
-        pass
-
-    def on_chain_error(
-        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> None:
-        logging.error(error)
-
     def on_tool_start(
         self,
         serialized: Dict[str, Any],

         self._agent_loops = []
         self._current_loop = None

-    def on_text(
-        self,
-        text: str,
-        color: Optional[str] = None,
-        end: str = "",
-        **kwargs: Optional[str],
-    ) -> None:
-        """Run on additional input from chains and agents."""
-        pass
-
     def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
         """Run on agent end."""
         # Final Answer

 from typing import Any, Dict, List, Union, Optional

 from langchain.callbacks.base import BaseCallbackHandler
-from langchain.schema import AgentAction, AgentFinish, LLMResult

 from core.callback_handler.entity.dataset_query import DatasetQueryObj
 from core.conversation_message_task import ConversationMessageTask


 class DatasetToolCallbackHandler(BaseCallbackHandler):
     """Callback Handler that prints to std out."""
+    raise_error: bool = True

     def __init__(self, conversation_message_task: ConversationMessageTask) -> None:
         """Initialize callback handler."""

     ) -> None:
         """Do nothing."""
         logging.error(error)
-
-    def on_chain_start(
-        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
-    ) -> None:
-        pass
-
-    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
-        pass
-
-    def on_chain_error(
-        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> None:
-        pass
-
-    def on_llm_start(
-        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
-    ) -> None:
-        pass
-
-    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
-        pass
-
-    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
-        """Do nothing."""
-        pass
-
-    def on_llm_error(
-        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> None:
-        logging.error(error)
-
-    def on_agent_action(
-        self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
-    ) -> Any:
-        pass
-
-    def on_text(
-        self,
-        text: str,
-        color: Optional[str] = None,
-        end: str = "",
-        **kwargs: Optional[str],
-    ) -> None:
-        """Run on additional input from chains and agents."""
-        pass
-
-    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
-        """Run on agent end."""
-        pass

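
The deleted methods here and in the other handlers are no-op overrides the old abstract BaseCallbackHandler forced every subclass to implement; the upgraded base class ships default no-ops, so only the events a handler cares about need overriding. The added raise_error: bool = True opts out of the new behaviour where the callback manager catches handler exceptions and only logs a warning. A minimal sketch, assuming that langchain behaviour:

from typing import Any, Dict

from langchain.callbacks.base import BaseCallbackHandler


class StrictToolHandler(BaseCallbackHandler):
    # Without this flag the callback manager swallows handler errors;
    # with it, exceptions raised here propagate to the caller.
    raise_error: bool = True

    def on_tool_start(self, serialized: Dict[str, Any], input_str: str, **kwargs: Any) -> None:
        if not input_str:
            raise ValueError("empty tool input")
    # every other on_* event falls back to the base class's default no-op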
-from llama_index import Response
+from typing import List
+
+from langchain.schema import Document

 from extensions.ext_database import db
 from models.dataset import DocumentSegment


-class IndexToolCallbackHandler:
-
-    def __init__(self) -> None:
-        self._response = None
-
-    @property
-    def response(self) -> Response:
-        return self._response
-
-    def on_tool_end(self, response: Response) -> None:
-        """Handle tool end."""
-        self._response = response
-
-
-class DatasetIndexToolCallbackHandler(IndexToolCallbackHandler):
+class DatasetIndexToolCallbackHandler:
     """Callback handler for dataset tool."""

     def __init__(self, dataset_id: str) -> None:
-        super().__init__()
         self.dataset_id = dataset_id

-    def on_tool_end(self, response: Response) -> None:
+    def on_tool_end(self, documents: List[Document]) -> None:
         """Handle tool end."""
-        for node in response.source_nodes:
-            index_node_id = node.node.doc_id
+        for document in documents:
+            doc_id = document.metadata['doc_id']

             # add hit count to document segment
             db.session.query(DocumentSegment).filter(
                 DocumentSegment.dataset_id == self.dataset_id,
-                DocumentSegment.index_node_id == index_node_id
+                DocumentSegment.index_node_id == doc_id
             ).update(
                 {DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
                 synchronize_session=False

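
The handler now receives plain langchain Documents instead of a llama_index Response, and it looks each segment up by metadata['doc_id']. A sketch of the expected input shape (the ids are made up; actually calling on_tool_end needs the Flask app and database context):

from langchain.schema import Document

retrieved_docs = [
    Document(page_content="first matched segment ...", metadata={"doc_id": "node-id-1"}),
    Document(page_content="second matched segment ...", metadata={"doc_id": "node-id-2"}),
]

# handler = DatasetIndexToolCallbackHandler(dataset_id="<dataset-uuid>")
# handler.on_tool_end(retrieved_docs)  # bumps hit_count on the matching DocumentSegments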
 from typing import Any, Dict, List, Union, Optional

 from langchain.callbacks.base import BaseCallbackHandler
-from langchain.schema import AgentAction, AgentFinish, LLMResult, HumanMessage, AIMessage, SystemMessage
+from langchain.schema import AgentAction, AgentFinish, LLMResult, HumanMessage, AIMessage, SystemMessage, BaseMessage

 from core.callback_handler.entity.llm_message import LLMMessage
 from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException


 class LLMCallbackHandler(BaseCallbackHandler):
+    raise_error: bool = True

     def __init__(self, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
                  conversation_message_task: ConversationMessageTask):

         """Whether to call verbose callbacks even if verbose is False."""
         return True

+    def on_chat_model_start(
+            self,
+            serialized: Dict[str, Any],
+            messages: List[List[BaseMessage]],
+            **kwargs: Any
+    ) -> Any:
+        self.start_at = time.perf_counter()
+        real_prompts = []
+        for message in messages[0]:
+            if message.type == 'human':
+                role = 'user'
+            elif message.type == 'ai':
+                role = 'assistant'
+            else:
+                role = 'system'
+
+            real_prompts.append({
+                "role": role,
+                "text": message.content
+            })
+
+        self.llm_message.prompt = real_prompts
+        self.llm_message.prompt_tokens = self.llm.get_messages_tokens(messages[0])
+
     def on_llm_start(
         self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
     ) -> None:
         self.start_at = time.perf_counter()

-        if 'Chat' in serialized['name']:
-            real_prompts = []
-            messages = []
-            for prompt in prompts:
-                role, content = prompt.split(': ', maxsplit=1)
-                if role == 'human':
-                    role = 'user'
-                    message = HumanMessage(content=content)
-                elif role == 'ai':
-                    role = 'assistant'
-                    message = AIMessage(content=content)
-                else:
-                    message = SystemMessage(content=content)
-
-                real_prompt = {
-                    "role": role,
-                    "text": content
-                }
-                real_prompts.append(real_prompt)
-                messages.append(message)
-
-            self.llm_message.prompt = real_prompts
-            self.llm_message.prompt_tokens = self.llm.get_messages_tokens(messages)
-        else:
-            self.llm_message.prompt = [{
-                "role": 'user',
-                "text": prompts[0]
-            }]
-
-            self.llm_message.prompt_tokens = self.llm.get_num_tokens(prompts[0])
+        self.llm_message.prompt = [{
+            "role": 'user',
+            "text": prompts[0]
+        }]
+
+        self.llm_message.prompt_tokens = self.llm.get_num_tokens(prompts[0])

     def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
         end_at = time.perf_counter()

             self.conversation_message_task.save_message(llm_message=self.llm_message, by_stopped=True)
         else:
             logging.error(error)

-    def on_chain_start(
-        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
-    ) -> None:
-        pass
-
-    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
-        pass
-
-    def on_chain_error(
-        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> None:
-        pass
-
-    def on_tool_start(
-        self,
-        serialized: Dict[str, Any],
-        input_str: str,
-        **kwargs: Any,
-    ) -> None:
-        pass
-
-    def on_agent_action(
-        self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
-    ) -> Any:
-        pass
-
-    def on_tool_end(
-        self,
-        output: str,
-        color: Optional[str] = None,
-        observation_prefix: Optional[str] = None,
-        llm_prefix: Optional[str] = None,
-        **kwargs: Any,
-    ) -> None:
-        pass
-
-    def on_tool_error(
-        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> None:
-        pass
-
-    def on_text(
-        self,
-        text: str,
-        color: Optional[str] = None,
-        end: str = "",
-        **kwargs: Optional[str],
-    ) -> None:
-        pass
-
-    def on_agent_finish(
-        self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
-    ) -> None:
-        pass

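
The new on_chat_model_start hook exists because chat models in the upgraded langchain no longer funnel through on_llm_start with stringified prompts; they deliver batches of BaseMessage objects whose type is 'human', 'ai' or 'system'. A standalone sketch of the role mapping performed above:

from langchain.schema import AIMessage, HumanMessage, SystemMessage

def to_role_dicts(messages):
    role_map = {'human': 'user', 'ai': 'assistant'}
    return [{"role": role_map.get(m.type, 'system'), "text": m.content} for m in messages]

batch = [SystemMessage(content="You are helpful."),
         HumanMessage(content="Hi there"),
         AIMessage(content="Hello!")]
print(to_role_dicts(batch))
# [{'role': 'system', ...}, {'role': 'user', ...}, {'role': 'assistant', ...}]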
 import logging
 import time
-from typing import Any, Dict, List, Union, Optional
+from typing import Any, Dict, Union

 from langchain.callbacks.base import BaseCallbackHandler
-from langchain.schema import AgentAction, AgentFinish, LLMResult

 from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
 from core.callback_handler.entity.chain_result import ChainResult


 class MainChainGatherCallbackHandler(BaseCallbackHandler):
     """Callback Handler that prints to std out."""
+    raise_error: bool = True

     def __init__(self, conversation_message_task: ConversationMessageTask) -> None:
         """Initialize callback handler."""

     ) -> None:
         """Print out that we are entering a chain."""
         if not self._current_chain_result:
-            self._current_chain_result = ChainResult(
-                type=serialized['name'],
-                prompt=inputs,
-                started_at=time.perf_counter()
-            )
-            self._current_chain_message = self.conversation_message_task.init_chain(self._current_chain_result)
-            self.agent_loop_gather_callback_handler.current_chain = self._current_chain_message
+            chain_type = serialized['id'][-1]
+            if chain_type:
+                self._current_chain_result = ChainResult(
+                    type=chain_type,
+                    prompt=inputs,
+                    started_at=time.perf_counter()
+                )
+                self._current_chain_message = self.conversation_message_task.init_chain(self._current_chain_result)
+                self.agent_loop_gather_callback_handler.current_chain = self._current_chain_message

     def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
         """Print out that we finished a chain."""

         self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
     ) -> None:
         logging.error(error)
-        self.clear_chain_results()
-
-    def on_llm_start(
-        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
-    ) -> None:
-        pass
-
-    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
-        pass
-
-    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
-        """Do nothing."""
-        pass
-
-    def on_llm_error(
-        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> None:
-        logging.error(error)
-
-    def on_tool_start(
-        self,
-        serialized: Dict[str, Any],
-        input_str: str,
-        **kwargs: Any,
-    ) -> None:
-        pass
-
-    def on_agent_action(
-        self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
-    ) -> Any:
-        pass
-
-    def on_tool_end(
-        self,
-        output: str,
-        color: Optional[str] = None,
-        observation_prefix: Optional[str] = None,
-        llm_prefix: Optional[str] = None,
-        **kwargs: Any,
-    ) -> None:
-        pass
-
-    def on_tool_error(
-        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> None:
-        """Do nothing."""
-        logging.error(error)
-
-    def on_text(
-        self,
-        text: str,
-        color: Optional[str] = None,
-        end: str = "",
-        **kwargs: Optional[str],
-    ) -> None:
-        """Run on additional input from chains and agents."""
-        pass
-
-    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
-        """Run on agent end."""
-        pass
+        self.clear_chain_results()

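
The switch from serialized['name'] to serialized['id'][-1] follows the upgraded langchain serialization format, where callbacks receive the class path as a list under 'id'; the last element is the chain's class name. For example:

# shape of the serialized payload passed to on_chain_start in newer langchain
serialized = {"lc": 1, "type": "constructor", "id": ["langchain", "chains", "llm", "LLMChain"]}
chain_type = serialized["id"][-1]
assert chain_type == "LLMChain"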
+import os
 import sys
 from typing import Any, Dict, List, Optional, Union

 from langchain.callbacks.base import BaseCallbackHandler
 from langchain.input import print_text
-from langchain.schema import AgentAction, AgentFinish, LLMResult
+from langchain.schema import AgentAction, AgentFinish, LLMResult, BaseMessage


 class DifyStdOutCallbackHandler(BaseCallbackHandler):

         """Initialize callback handler."""
         self.color = color

+    def on_chat_model_start(
+            self,
+            serialized: Dict[str, Any],
+            messages: List[List[BaseMessage]],
+            **kwargs: Any
+    ) -> Any:
+        print_text("\n[on_chat_model_start]\n", color='blue')
+        for sub_messages in messages:
+            for sub_message in sub_messages:
+                print_text(str(sub_message) + "\n", color='blue')
+
     def on_llm_start(
         self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
     ) -> None:
         """Print out the prompts."""
         print_text("\n[on_llm_start]\n", color='blue')
-
-        if 'Chat' in serialized['name']:
-            for prompt in prompts:
-                print_text(prompt + "\n", color='blue')
-        else:
-            print_text(prompts[0] + "\n", color='blue')
+        print_text(prompts[0] + "\n", color='blue')

     def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
         """Do nothing."""

         self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
     ) -> None:
         """Print out that we are entering a chain."""
-        class_name = serialized["name"]
-        print_text("\n[on_chain_start]\nChain: " + class_name + "\nInputs: " + str(inputs) + "\n", color='pink')
+        chain_type = serialized['id'][-1]
+        print_text("\n[on_chain_start]\nChain: " + chain_type + "\nInputs: " + str(inputs) + "\n", color='pink')

     def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
         """Print out that we finished a chain."""

         """Run on agent end."""
         print_text("[on_agent_finish] " + finish.return_values['output'] + "\n", color='green', end="\n")

+    @property
+    def ignore_llm(self) -> bool:
+        """Whether to ignore LLM callbacks."""
+        return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
+
+    @property
+    def ignore_chain(self) -> bool:
+        """Whether to ignore chain callbacks."""
+        return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
+
+    @property
+    def ignore_agent(self) -> bool:
+        """Whether to ignore agent callbacks."""
+        return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
+
+    @property
+    def ignore_chat_model(self) -> bool:
+        """Whether to ignore chat model callbacks."""
+        return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
+

 class DifyStreamingStdOutCallbackHandler(DifyStdOutCallbackHandler):
     """Callback handler for streaming. Only works with LLMs that support streaming."""

 from typing import Optional

-from langchain.callbacks import CallbackManager

 from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
 from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain
 from core.chain.tool_chain import ToolChain

             tool=tool,
             input_key=kwargs.get('input_key', 'input'),
             output_key=kwargs.get('output_key', 'tool_output'),
-            callback_manager=CallbackManager([DifyStdOutCallbackHandler()])
+            callbacks=[DifyStdOutCallbackHandler()]
         )

     @classmethod

             sensitive_words=sensitive_words.split(","),
             canned_response=tool_config.get("canned_response", ''),
             output_key="sensitive_word_avoidance_output",
-            callback_manager=CallbackManager([DifyStdOutCallbackHandler()]),
+            callbacks=[DifyStdOutCallbackHandler()],
             **kwargs
         )

| """Base classes for LLM-powered router chains.""" | """Base classes for LLM-powered router chains.""" | ||||
| from __future__ import annotations | from __future__ import annotations | ||||
| import json | |||||
| from typing import Any, Dict, List, Optional, Type, cast, NamedTuple | from typing import Any, Dict, List, Optional, Type, cast, NamedTuple | ||||
| from langchain.base_language import BaseLanguageModel | |||||
| from langchain.callbacks.manager import CallbackManagerForChainRun | |||||
| from langchain.chains.base import Chain | from langchain.chains.base import Chain | ||||
| from pydantic import root_validator | from pydantic import root_validator | ||||
| from langchain.chains import LLMChain | from langchain.chains import LLMChain | ||||
| from langchain.prompts import BasePromptTemplate | from langchain.prompts import BasePromptTemplate | ||||
| from langchain.schema import BaseOutputParser, OutputParserException, BaseLanguageModel | |||||
| from langchain.schema import BaseOutputParser, OutputParserException | |||||
| from libs.json_in_md_parser import parse_and_check_json_markdown | from libs.json_in_md_parser import parse_and_check_json_markdown | ||||
| raise ValueError | raise ValueError | ||||
| def _call( | def _call( | ||||
| self, | |||||
| inputs: Dict[str, Any] | |||||
| self, | |||||
| inputs: Dict[str, Any], | |||||
| run_manager: Optional[CallbackManagerForChainRun] = None, | |||||
| ) -> Dict[str, Any]: | ) -> Dict[str, Any]: | ||||
| output = cast( | output = cast( | ||||
| Dict[str, Any], | Dict[str, Any], |
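
_call gaining an optional run_manager parameter is part of the upgraded Chain interface; the same signature change repeats in the router, sensitive-word and tool chains below. A minimal custom chain against that interface, assuming a langchain release where Chain._call accepts run_manager:

from typing import Any, Dict, List, Optional

from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain


class UppercaseChain(Chain):
    input_key: str = "input"
    output_key: str = "output"

    @property
    def input_keys(self) -> List[str]:
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        return [self.output_key]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        return {self.output_key: inputs[self.input_key].upper()}


print(UppercaseChain().run(input="hello"))  # -> "HELLO"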
-from typing import Optional, List
+from typing import Optional, List, cast

-from langchain.callbacks import SharedCallbackManager, CallbackManager
 from langchain.chains import SequentialChain
 from langchain.chains.base import Chain
 from langchain.memory.chat_memory import BaseChatMemory

-from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
 from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
 from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
 from core.chain.chain_builder import ChainBuilder


 class MainChainBuilder:
     @classmethod
     def to_langchain_components(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory],
+                                rest_tokens: int,
                                 conversation_message_task):
         first_input_key = "input"
         final_output_key = "output"

         tool_chains, chains_output_key = cls.get_agent_chains(
             tenant_id=tenant_id,
             agent_mode=agent_mode,
+            rest_tokens=rest_tokens,
             memory=memory,
             conversation_message_task=conversation_message_task
         )

             return None

         for chain in chains:
-            # do not add handler into singleton callback manager
-            if not isinstance(chain.callback_manager, SharedCallbackManager):
-                chain.callback_manager.add_handler(chain_callback_handler)
+            chain = cast(Chain, chain)
+            chain.callbacks.append(chain_callback_handler)

         # build main chain
         overall_chain = SequentialChain(

         return overall_chain

     @classmethod
-    def get_agent_chains(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory],
+    def get_agent_chains(cls, tenant_id: str, agent_mode: dict,
+                         rest_tokens: int,
+                         memory: Optional[BaseChatMemory],
                          conversation_message_task):
         # agent mode
         chains = []

                 tenant_id=tenant_id,
                 datasets=datasets,
                 conversation_message_task=conversation_message_task,
-                callback_manager=CallbackManager([DifyStdOutCallbackHandler()])
+                rest_tokens=rest_tokens,
+                callbacks=[DifyStdOutCallbackHandler()]
             )
             chains.append(multi_dataset_router_chain)

+import math
 from typing import Mapping, List, Dict, Any, Optional

-from langchain import LLMChain, PromptTemplate, ConversationChain
-from langchain.callbacks import CallbackManager
+from langchain import PromptTemplate
+from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.base import Chain
-from langchain.schema import BaseLanguageModel
 from pydantic import Extra

 from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
 from core.chain.llm_router_chain import LLMRouterChain, RouterOutputParser
 from core.conversation_message_task import ConversationMessageTask
 from core.llm.llm_builder import LLMBuilder
-from core.tool.dataset_tool_builder import DatasetToolBuilder
-from core.tool.llama_index_tool import EnhanceLlamaIndexTool
-from models.dataset import Dataset
+from core.tool.dataset_index_tool import DatasetTool
+from models.dataset import Dataset, DatasetProcessRule

+DEFAULT_K = 2
+CONTEXT_TOKENS_PERCENT = 0.3
 MULTI_PROMPT_ROUTER_TEMPLATE = """
 Given a raw text input to a language model select the model prompt best suited for \
 the input. You will be given the names of the available prompts and a description of \

     router_chain: LLMRouterChain
     """Chain for deciding a destination chain and the input to it."""
-    dataset_tools: Mapping[str, EnhanceLlamaIndexTool]
+    dataset_tools: Mapping[str, DatasetTool]
     """Map of name to candidate chains that inputs can be routed to."""

     class Config:

             tenant_id: str,
             datasets: List[Dataset],
             conversation_message_task: ConversationMessageTask,
+            rest_tokens: int,
             **kwargs: Any,
     ):
         """Convenience constructor for instantiating from destination prompts."""
-        llm_callback_manager = CallbackManager([DifyStdOutCallbackHandler()])
         llm = LLMBuilder.to_llm(
             tenant_id=tenant_id,
             model_name='gpt-3.5-turbo',
             temperature=0,
             max_tokens=1024,
-            callback_manager=llm_callback_manager
+            callbacks=[DifyStdOutCallbackHandler()]
         )

-        destinations = ["{}: {}".format(d.id, d.description.replace('\n', ' ') if d.description
+        destinations = ["[[{}]]: {}".format(d.id, d.description.replace('\n', ' ') if d.description
                         else ('useful for when you want to answer queries about the ' + d.name))
                         for d in datasets]
         destinations_str = "\n".join(destinations)
         router_template = MULTI_PROMPT_ROUTER_TEMPLATE.format(
             destinations=destinations_str
         )
         router_prompt = PromptTemplate(
             template=router_template,
             input_variables=["input"],
             output_parser=RouterOutputParser(),
         )
         router_chain = LLMRouterChain.from_llm(llm, router_prompt)
         dataset_tools = {}
         for dataset in datasets:
-            dataset_tool = DatasetToolBuilder.build_dataset_tool(
+            # fulfill description when it is empty
+            if dataset.available_document_count == 0 or dataset.available_document_count == 0:
+                continue
+
+            description = dataset.description
+            if not description:
+                description = 'useful for when you want to answer queries about the ' + dataset.name
+
+            k = cls._dynamic_calc_retrieve_k(dataset, rest_tokens)
+            if k == 0:
+                continue
+
+            dataset_tool = DatasetTool(
+                name=f"dataset-{dataset.id}",
+                description=description,
+                k=k,
                 dataset=dataset,
-                response_mode='no_synthesizer',  # "compact"
-                callback_handler=DatasetToolCallbackHandler(conversation_message_task)
+                callbacks=[DatasetToolCallbackHandler(conversation_message_task), DifyStdOutCallbackHandler()]
             )

-            if dataset_tool:
-                dataset_tools[dataset.id] = dataset_tool
+            dataset_tools[str(dataset.id)] = dataset_tool

         return cls(
             router_chain=router_chain,
             **kwargs,
         )

+    @classmethod
+    def _dynamic_calc_retrieve_k(cls, dataset: Dataset, rest_tokens: int) -> int:
+        processing_rule = dataset.latest_process_rule
+        if not processing_rule:
+            return DEFAULT_K
+
+        if processing_rule.mode == "custom":
+            rules = processing_rule.rules_dict
+            if not rules:
+                return DEFAULT_K
+
+            segmentation = rules["segmentation"]
+            segment_max_tokens = segmentation["max_tokens"]
+        else:
+            segment_max_tokens = DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens']
+
+        # when rest_tokens is less than default context tokens
+        if rest_tokens < segment_max_tokens * DEFAULT_K:
+            return rest_tokens // segment_max_tokens
+
+        context_limit_tokens = math.floor(rest_tokens * CONTEXT_TOKENS_PERCENT)
+
+        # when context_limit_tokens is less than default context tokens, use default_k
+        if context_limit_tokens <= segment_max_tokens * DEFAULT_K:
+            return DEFAULT_K
+
+        # Expand the k value when there's still some room left in the 30% rest tokens space
+        return context_limit_tokens // segment_max_tokens
+
     def _call(
             self,
-            inputs: Dict[str, Any]
+            inputs: Dict[str, Any],
+            run_manager: Optional[CallbackManagerForChainRun] = None,
     ) -> Dict[str, Any]:
         if len(self.dataset_tools) == 0:
             return {"text": ''}

-from typing import List, Dict
+from typing import List, Dict, Optional, Any

+from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.base import Chain

             return self.canned_response
         return text

-    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
+    def _call(
+            self,
+            inputs: Dict[str, Any],
+            run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, Any]:
         text = inputs[self.input_key]
         output = self._check_sensitive_word(text)
         return {self.output_key: output}

-from typing import List, Dict
+from typing import List, Dict, Optional, Any

+from langchain.callbacks.manager import CallbackManagerForChainRun, AsyncCallbackManagerForChainRun
 from langchain.chains.base import Chain
 from langchain.tools import BaseTool

         """
         return [self.output_key]

-    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
+    def _call(
+            self,
+            inputs: Dict[str, Any],
+            run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, Any]:
         input = inputs[self.input_key]
         output = self.tool.run(input, self.verbose)
         return {self.output_key: output}

-    async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]:
+    async def _acall(
+            self,
+            inputs: Dict[str, Any],
+            run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
+    ) -> Dict[str, Any]:
         """Run the logic of this chain and return the output."""
         input = inputs[self.input_key]
         output = await self.tool.arun(input, self.verbose)

| import logging | import logging | ||||
| from typing import Optional, List, Union, Tuple | from typing import Optional, List, Union, Tuple | ||||
| from langchain.callbacks import CallbackManager | |||||
| from langchain.base_language import BaseLanguageModel | |||||
| from langchain.callbacks.base import BaseCallbackHandler | |||||
| from langchain.chat_models.base import BaseChatModel | from langchain.chat_models.base import BaseChatModel | ||||
| from langchain.llms import BaseLLM | from langchain.llms import BaseLLM | ||||
| from langchain.schema import BaseMessage, BaseLanguageModel, HumanMessage | |||||
| from langchain.schema import BaseMessage, HumanMessage | |||||
| from requests.exceptions import ChunkedEncodingError | from requests.exceptions import ChunkedEncodingError | ||||
| from core.constant import llm_constant | from core.constant import llm_constant | ||||
| from core.callback_handler.llm_callback_handler import LLMCallbackHandler | from core.callback_handler.llm_callback_handler import LLMCallbackHandler | ||||
| from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, \ | from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, \ | ||||
| DifyStdOutCallbackHandler | DifyStdOutCallbackHandler | ||||
| from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, PubHandler | |||||
| from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException | |||||
| from core.llm.error import LLMBadRequestError | from core.llm.error import LLMBadRequestError | ||||
| from core.llm.llm_builder import LLMBuilder | from core.llm.llm_builder import LLMBuilder | ||||
| from core.chain.main_chain_builder import MainChainBuilder | from core.chain.main_chain_builder import MainChainBuilder | ||||
| """ | """ | ||||
| errors: ProviderTokenNotInitError | errors: ProviderTokenNotInitError | ||||
| """ | """ | ||||
| cls.validate_query_tokens(app.tenant_id, app_model_config, query) | |||||
| memory = None | memory = None | ||||
| if conversation: | if conversation: | ||||
| # get memory of conversation (read-only) | # get memory of conversation (read-only) | ||||
| inputs = conversation.inputs | inputs = conversation.inputs | ||||
| rest_tokens_for_context_and_memory = cls.get_validate_rest_tokens( | |||||
| mode=app.mode, | |||||
| tenant_id=app.tenant_id, | |||||
| app_model_config=app_model_config, | |||||
| query=query, | |||||
| inputs=inputs | |||||
| ) | |||||
| conversation_message_task = ConversationMessageTask( | conversation_message_task = ConversationMessageTask( | ||||
| task_id=task_id, | task_id=task_id, | ||||
| app=app, | app=app, | ||||
| main_chain = MainChainBuilder.to_langchain_components( | main_chain = MainChainBuilder.to_langchain_components( | ||||
| tenant_id=app.tenant_id, | tenant_id=app.tenant_id, | ||||
| agent_mode=app_model_config.agent_mode_dict, | agent_mode=app_model_config.agent_mode_dict, | ||||
| rest_tokens=rest_tokens_for_context_and_memory, | |||||
| memory=ReadOnlyConversationTokenDBStringBufferSharedMemory(memory=memory) if memory else None, | memory=ReadOnlyConversationTokenDBStringBufferSharedMemory(memory=memory) if memory else None, | ||||
| conversation_message_task=conversation_message_task | conversation_message_task=conversation_message_task | ||||
| ) | ) | ||||
| memory=memory | memory=memory | ||||
| ) | ) | ||||
| final_llm.callback_manager = cls.get_llm_callback_manager(final_llm, streaming, conversation_message_task) | |||||
| final_llm.callbacks = cls.get_llm_callbacks(final_llm, streaming, conversation_message_task) | |||||
| cls.recale_llm_max_tokens( | cls.recale_llm_max_tokens( | ||||
| final_llm=final_llm, | final_llm=final_llm, | ||||
| return messages, ['\nHuman:'] | return messages, ['\nHuman:'] | ||||
| @classmethod | @classmethod | ||||
| def get_llm_callback_manager(cls, llm: Union[StreamableOpenAI, StreamableChatOpenAI], | |||||
| streaming: bool, | |||||
| conversation_message_task: ConversationMessageTask) -> CallbackManager: | |||||
| def get_llm_callbacks(cls, llm: Union[StreamableOpenAI, StreamableChatOpenAI], | |||||
| streaming: bool, | |||||
| conversation_message_task: ConversationMessageTask) -> List[BaseCallbackHandler]: | |||||
| llm_callback_handler = LLMCallbackHandler(llm, conversation_message_task) | llm_callback_handler = LLMCallbackHandler(llm, conversation_message_task) | ||||
| if streaming: | if streaming: | ||||
| callback_handlers = [llm_callback_handler, DifyStreamingStdOutCallbackHandler()] | |||||
| return [llm_callback_handler, DifyStreamingStdOutCallbackHandler()] | |||||
| else: | else: | ||||
| callback_handlers = [llm_callback_handler, DifyStdOutCallbackHandler()] | |||||
| return CallbackManager(callback_handlers) | |||||
| return [llm_callback_handler, DifyStdOutCallbackHandler()] | |||||
| @classmethod | @classmethod | ||||
| def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory, | def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory, | ||||
| return memory | return memory | ||||
| @classmethod | @classmethod | ||||
| def validate_query_tokens(cls, tenant_id: str, app_model_config: AppModelConfig, query: str): | |||||
| def get_validate_rest_tokens(cls, mode: str, tenant_id: str, app_model_config: AppModelConfig, | |||||
| query: str, inputs: dict) -> int: | |||||
| llm = LLMBuilder.to_llm_from_model( | llm = LLMBuilder.to_llm_from_model( | ||||
| tenant_id=tenant_id, | tenant_id=tenant_id, | ||||
| model=app_model_config.model_dict | model=app_model_config.model_dict | ||||
| model_limited_tokens = llm_constant.max_context_token_length[llm.model_name] | model_limited_tokens = llm_constant.max_context_token_length[llm.model_name] | ||||
| max_tokens = llm.max_tokens | max_tokens = llm.max_tokens | ||||
| if model_limited_tokens - max_tokens - llm.get_num_tokens(query) < 0: | |||||
| raise LLMBadRequestError("Query is too long") | |||||
| # get prompt without memory and context | |||||
| prompt, _ = cls.get_main_llm_prompt( | |||||
| mode=mode, | |||||
| llm=llm, | |||||
| pre_prompt=app_model_config.pre_prompt, | |||||
| query=query, | |||||
| inputs=inputs, | |||||
| chain_output=None, | |||||
| memory=None | |||||
| ) | |||||
| prompt_tokens = llm.get_num_tokens(prompt) if isinstance(prompt, str) \ | |||||
| else llm.get_num_tokens_from_messages(prompt) | |||||
| rest_tokens = model_limited_tokens - max_tokens - prompt_tokens | |||||
| if rest_tokens < 0: | |||||
| raise LLMBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, " | |||||
| "or shrink the max token, or switch to a llm with a larger token limit size.") | |||||
| return rest_tokens | |||||
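The budget returned by get_validate_rest_tokens is plain arithmetic: the model's context limit minus the completion budget minus the tokens of the bare prompt. A minimal sketch of that calculation, with hypothetical numbers in place of the values the real code reads from llm_constant and the app model config:

# Hypothetical values for illustration; the real code takes them from
# llm_constant.max_context_token_length, llm.max_tokens and llm.get_num_tokens().
model_limited_tokens = 4096   # context window of the selected model
max_tokens = 512              # completion budget reserved for the answer
prompt_tokens = 1200          # tokens of the prompt without memory or context

rest_tokens = model_limited_tokens - max_tokens - prompt_tokens  # 2384 left for context and memory
if rest_tokens < 0:
    raise ValueError("prompt plus completion budget exceeds the model's context window")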
| @classmethod | @classmethod | ||||
| def recale_llm_max_tokens(cls, final_llm: Union[StreamableOpenAI, StreamableChatOpenAI], | def recale_llm_max_tokens(cls, final_llm: Union[StreamableOpenAI, StreamableChatOpenAI], | ||||
| streaming=streaming | streaming=streaming | ||||
| ) | ) | ||||
| llm.callback_manager = cls.get_llm_callback_manager(llm, streaming, conversation_message_task) | |||||
| llm.callbacks = cls.get_llm_callbacks(llm, streaming, conversation_message_task) | |||||
| cls.recale_llm_max_tokens( | cls.recale_llm_max_tokens( | ||||
| final_llm=llm, | final_llm=llm, |
| if not user: | if not user: | ||||
| raise ValueError("user is required") | raise ValueError("user is required") | ||||
| user_str = 'account-' + user.id if isinstance(user, Account) else 'end-user-' + user.id | |||||
| user_str = 'account-' + str(user.id) if isinstance(user, Account) else 'end-user-' + str(user.id) | |||||
| return "generate_result:{}-{}".format(user_str, task_id) | return "generate_result:{}-{}".format(user_str, task_id) | ||||
| @classmethod | @classmethod | ||||
| def generate_stopped_cache_key(cls, user: Union[Account | EndUser], task_id: str): | def generate_stopped_cache_key(cls, user: Union[Account | EndUser], task_id: str): | ||||
| user_str = 'account-' + user.id if isinstance(user, Account) else 'end-user-' + user.id | |||||
| user_str = 'account-' + str(user.id) if isinstance(user, Account) else 'end-user-' + str(user.id) | |||||
| return "generate_result_stopped:{}-{}".format(user_str, task_id) | return "generate_result_stopped:{}-{}".format(user_str, task_id) | ||||
| def pub_text(self, text: str): | def pub_text(self, text: str): | ||||
| 'event': 'message', | 'event': 'message', | ||||
| 'data': { | 'data': { | ||||
| 'task_id': self._task_id, | 'task_id': self._task_id, | ||||
| 'message_id': self._message.id, | |||||
| 'message_id': str(self._message.id), | |||||
| 'text': text, | 'text': text, | ||||
| 'mode': self._conversation.mode, | 'mode': self._conversation.mode, | ||||
| 'conversation_id': self._conversation.id | |||||
| 'conversation_id': str(self._conversation.id) | |||||
| } | } | ||||
| } | } | ||||
| import tempfile | |||||
| from pathlib import Path | |||||
| from typing import List, Union | |||||
| from langchain.document_loaders import TextLoader, Docx2txtLoader | |||||
| from langchain.schema import Document | |||||
| from core.data_loader.loader.csv import CSVLoader | |||||
| from core.data_loader.loader.excel import ExcelLoader | |||||
| from core.data_loader.loader.html import HTMLLoader | |||||
| from core.data_loader.loader.markdown import MarkdownLoader | |||||
| from core.data_loader.loader.pdf import PdfLoader | |||||
| from extensions.ext_storage import storage | |||||
| from models.model import UploadFile | |||||
| class FileExtractor: | |||||
| @classmethod | |||||
| def load(cls, upload_file: UploadFile, return_text: bool = False) -> Union[List[Document], str]: | |||||
| with tempfile.TemporaryDirectory() as temp_dir: | |||||
| suffix = Path(upload_file.key).suffix | |||||
| file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" | |||||
| storage.download(upload_file.key, file_path) | |||||
| input_file = Path(file_path) | |||||
| delimiter = '\n' | |||||
| if input_file.suffix == '.xlsx': | |||||
| loader = ExcelLoader(file_path) | |||||
| elif input_file.suffix == '.pdf': | |||||
| loader = PdfLoader(file_path, upload_file=upload_file) | |||||
| elif input_file.suffix in ['.md', '.markdown']: | |||||
| loader = MarkdownLoader(file_path, autodetect_encoding=True) | |||||
| elif input_file.suffix in ['.htm', '.html']: | |||||
| loader = HTMLLoader(file_path) | |||||
| elif input_file.suffix == '.docx': | |||||
| loader = Docx2txtLoader(file_path) | |||||
| elif input_file.suffix == '.csv': | |||||
| loader = CSVLoader(file_path, autodetect_encoding=True) | |||||
| else: | |||||
| # txt | |||||
| loader = TextLoader(file_path, autodetect_encoding=True) | |||||
| return delimiter.join([document.page_content for document in loader.load()]) if return_text else loader.load() |
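A minimal usage sketch of the new FileExtractor entry point, assuming an UploadFile row whose key already points at an object in storage; the module path is inferred from the loader imports above:

# Sketch only: the UploadFile lookup is hypothetical and requires the Flask app context.
from core.data_loader.file_extractor import FileExtractor  # module path assumed
from extensions.ext_database import db
from models.model import UploadFile

upload_file = db.session.query(UploadFile).first()               # hypothetical lookup
documents = FileExtractor.load(upload_file)                      # List[Document], one per loader chunk
plain_text = FileExtractor.load(upload_file, return_text=True)   # single newline-joined string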
| import csv | |||||
| import logging | |||||
| from typing import Optional, Dict, List | |||||
| from langchain.document_loaders import CSVLoader as LCCSVLoader | |||||
| from langchain.document_loaders.helpers import detect_file_encodings | |||||
| from langchain.schema import Document | |||||
| logger = logging.getLogger(__name__) | |||||
| class CSVLoader(LCCSVLoader): | |||||
| def __init__( | |||||
| self, | |||||
| file_path: str, | |||||
| source_column: Optional[str] = None, | |||||
| csv_args: Optional[Dict] = None, | |||||
| encoding: Optional[str] = None, | |||||
| autodetect_encoding: bool = True, | |||||
| ): | |||||
| self.file_path = file_path | |||||
| self.source_column = source_column | |||||
| self.encoding = encoding | |||||
| self.csv_args = csv_args or {} | |||||
| self.autodetect_encoding = autodetect_encoding | |||||
| def load(self) -> List[Document]: | |||||
| """Load data into document objects.""" | |||||
| try: | |||||
| with open(self.file_path, newline="", encoding=self.encoding) as csvfile: | |||||
| docs = self._read_from_file(csvfile) | |||||
| except UnicodeDecodeError as e: | |||||
| if self.autodetect_encoding: | |||||
| detected_encodings = detect_file_encodings(self.file_path) | |||||
| for encoding in detected_encodings: | |||||
| logger.debug("Trying encoding: ", encoding.encoding) | |||||
| try: | |||||
| with open(self.file_path, newline="", encoding=encoding.encoding) as csvfile: | |||||
| docs = self._read_from_file(csvfile) | |||||
| break | |||||
| except UnicodeDecodeError: | |||||
| continue | |||||
| else: | |||||
| raise RuntimeError(f"Error loading {self.file_path}") from e | |||||
| return docs | |||||
| def _read_from_file(self, csvfile): | |||||
| docs = [] | |||||
| csv_reader = csv.DictReader(csvfile, **self.csv_args) # type: ignore | |||||
| for i, row in enumerate(csv_reader): | |||||
| content = "\n".join(f"{k.strip()}: {v.strip()}" for k, v in row.items()) | |||||
| try: | |||||
| source = ( | |||||
| row[self.source_column] | |||||
| if self.source_column is not None | |||||
| else '' | |||||
| ) | |||||
| except KeyError: | |||||
| raise ValueError( | |||||
| f"Source column '{self.source_column}' not found in CSV file." | |||||
| ) | |||||
| metadata = {"source": source, "row": i} | |||||
| doc = Document(page_content=content, metadata=metadata) | |||||
| docs.append(doc) | |||||
| return docs |
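For reference, the subclassed CSVLoader is called like the langchain original; a short sketch with a hypothetical file path:

from core.data_loader.loader.csv import CSVLoader  # path taken from the FileExtractor imports above

loader = CSVLoader('/tmp/example.csv', autodetect_encoding=True)  # hypothetical path
for doc in loader.load():
    print(doc.metadata['row'], doc.page_content[:80])  # one Document per CSV row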
| import json | |||||
| import logging | |||||
| from typing import List | |||||
| from langchain.document_loaders.base import BaseLoader | |||||
| from langchain.schema import Document | |||||
| from openpyxl.reader.excel import load_workbook | |||||
| logger = logging.getLogger(__name__) | |||||
| class ExcelLoader(BaseLoader): | |||||
| """Load xlxs files. | |||||
| Args: | |||||
| file_path: Path to the file to load. | |||||
| """ | |||||
| def __init__( | |||||
| self, | |||||
| file_path: str | |||||
| ): | |||||
| """Initialize with file path.""" | |||||
| self._file_path = file_path | |||||
| def load(self) -> List[Document]: | |||||
| data = [] | |||||
| keys = [] | |||||
| wb = load_workbook(filename=self._file_path, read_only=True) | |||||
| # loop over all sheets | |||||
| for sheet in wb: | |||||
| for row in sheet.iter_rows(values_only=True): | |||||
| if all(v is None for v in row): | |||||
| continue | |||||
| if keys == []: | |||||
| keys = list(map(str, row)) | |||||
| else: | |||||
| row_dict = dict(zip(keys, row)) | |||||
| row_dict = {k: v for k, v in row_dict.items() if v} | |||||
| data.append(json.dumps(row_dict, ensure_ascii=False)) | |||||
| return [Document(page_content='\n\n'.join(data))] |
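Each worksheet row becomes a JSON object keyed by the first non-empty row of the sheet, and all rows are joined into a single Document. A small sketch, with a hypothetical file path:

from core.data_loader.loader.excel import ExcelLoader  # path taken from the FileExtractor imports above

docs = ExcelLoader('/tmp/example.xlsx').load()  # hypothetical path
print(docs[0].page_content)  # one JSON object per data row, separated by blank lines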
| import logging | |||||
| from typing import List | |||||
| from bs4 import BeautifulSoup | |||||
| from langchain.document_loaders.base import BaseLoader | |||||
| from langchain.schema import Document | |||||
| logger = logging.getLogger(__name__) | |||||
| class HTMLLoader(BaseLoader): | |||||
| """Load html files. | |||||
| Args: | |||||
| file_path: Path to the file to load. | |||||
| """ | |||||
| def __init__( | |||||
| self, | |||||
| file_path: str | |||||
| ): | |||||
| """Initialize with file path.""" | |||||
| self._file_path = file_path | |||||
| def load(self) -> List[Document]: | |||||
| return [Document(page_content=self._load_as_text())] | |||||
| def _load_as_text(self) -> str: | |||||
| with open(self._file_path, "rb") as fp: | |||||
| soup = BeautifulSoup(fp, 'html.parser') | |||||
| text = soup.get_text() | |||||
| text = text.strip() if text else '' | |||||
| return text |
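The HTML loader simply strips tags with BeautifulSoup and returns a single Document. A sketch with a hypothetical path:

from core.data_loader.loader.html import HTMLLoader  # path taken from the FileExtractor imports above

docs = HTMLLoader('/tmp/example.html').load()  # hypothetical path
print(docs[0].page_content)  # tag-stripped text of the page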
| import logging | |||||
| import re | |||||
| from typing import Optional, List, Tuple, cast | |||||
| from langchain.document_loaders.base import BaseLoader | |||||
| from langchain.document_loaders.helpers import detect_file_encodings | |||||
| from langchain.schema import Document | |||||
| logger = logging.getLogger(__name__) | |||||
| class MarkdownLoader(BaseLoader): | |||||
| """Load md files. | |||||
| Args: | |||||
| file_path: Path to the file to load. | |||||
| remove_hyperlinks: Whether to remove hyperlinks from the text. | |||||
| remove_images: Whether to remove images from the text. | |||||
| encoding: File encoding to use. If `None`, the file will be loaded | |||||
| with the default system encoding. | |||||
| autodetect_encoding: Whether to try to autodetect the file encoding | |||||
| if the specified encoding fails. | |||||
| """ | |||||
| def __init__( | |||||
| self, | |||||
| file_path: str, | |||||
| remove_hyperlinks: bool = True, | |||||
| remove_images: bool = True, | |||||
| encoding: Optional[str] = None, | |||||
| autodetect_encoding: bool = True, | |||||
| ): | |||||
| """Initialize with file path.""" | |||||
| self._file_path = file_path | |||||
| self._remove_hyperlinks = remove_hyperlinks | |||||
| self._remove_images = remove_images | |||||
| self._encoding = encoding | |||||
| self._autodetect_encoding = autodetect_encoding | |||||
| def load(self) -> List[Document]: | |||||
| tups = self.parse_tups(self._file_path) | |||||
| documents = [] | |||||
| for header, value in tups: | |||||
| value = value.strip() | |||||
| if header is None: | |||||
| documents.append(Document(page_content=value)) | |||||
| else: | |||||
| documents.append(Document(page_content=f"\n\n{header}\n{value}")) | |||||
| return documents | |||||
| def markdown_to_tups(self, markdown_text: str) -> List[Tuple[Optional[str], str]]: | |||||
| """Convert a markdown file to a dictionary. | |||||
| The keys are the headers and the values are the text under each header. | |||||
| """ | |||||
| markdown_tups: List[Tuple[Optional[str], str]] = [] | |||||
| lines = markdown_text.split("\n") | |||||
| current_header = None | |||||
| current_text = "" | |||||
| for line in lines: | |||||
| header_match = re.match(r"^#+\s", line) | |||||
| if header_match: | |||||
| if current_header is not None: | |||||
| markdown_tups.append((current_header, current_text)) | |||||
| current_header = line | |||||
| current_text = "" | |||||
| else: | |||||
| current_text += line + "\n" | |||||
| markdown_tups.append((current_header, current_text)) | |||||
| if current_header is not None: | |||||
| # pass linting, assert keys are defined | |||||
| markdown_tups = [ | |||||
| (re.sub(r"#", "", cast(str, key)).strip(), re.sub(r"<.*?>", "", value)) | |||||
| for key, value in markdown_tups | |||||
| ] | |||||
| else: | |||||
| markdown_tups = [ | |||||
| (key, re.sub("\n", "", value)) for key, value in markdown_tups | |||||
| ] | |||||
| return markdown_tups | |||||
| def remove_images(self, content: str) -> str: | |||||
| """Get a dictionary of a markdown file from its path.""" | |||||
| pattern = r"!{1}\[\[(.*)\]\]" | |||||
| content = re.sub(pattern, "", content) | |||||
| return content | |||||
| def remove_hyperlinks(self, content: str) -> str: | |||||
| """Get a dictionary of a markdown file from its path.""" | |||||
| pattern = r"\[(.*?)\]\((.*?)\)" | |||||
| content = re.sub(pattern, r"\1", content) | |||||
| return content | |||||
| def parse_tups(self, filepath: str) -> List[Tuple[Optional[str], str]]: | |||||
| """Parse file into tuples.""" | |||||
| content = "" | |||||
| try: | |||||
| with open(filepath, "r", encoding=self._encoding) as f: | |||||
| content = f.read() | |||||
| except UnicodeDecodeError as e: | |||||
| if self._autodetect_encoding: | |||||
| detected_encodings = detect_file_encodings(filepath) | |||||
| for encoding in detected_encodings: | |||||
| logger.debug("Trying encoding: ", encoding.encoding) | |||||
| try: | |||||
| with open(filepath, encoding=encoding.encoding) as f: | |||||
| content = f.read() | |||||
| break | |||||
| except UnicodeDecodeError: | |||||
| continue | |||||
| else: | |||||
| raise RuntimeError(f"Error loading {filepath}") from e | |||||
| except Exception as e: | |||||
| raise RuntimeError(f"Error loading {filepath}") from e | |||||
| if self._remove_hyperlinks: | |||||
| content = self.remove_hyperlinks(content) | |||||
| if self._remove_images: | |||||
| content = self.remove_images(content) | |||||
| return self.markdown_to_tups(content) |
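A usage sketch of the markdown loader; each header section becomes its own Document, and hyperlinks and image references are stripped by default. The file path is hypothetical:

from core.data_loader.loader.markdown import MarkdownLoader  # path taken from the FileExtractor imports above

loader = MarkdownLoader('/tmp/example.md', remove_hyperlinks=True, remove_images=True)
for doc in loader.load():
    print(doc.page_content[:80])  # header followed by the text under it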
| """Notion reader.""" | |||||
| import json | import json | ||||
| import logging | import logging | ||||
| import os | |||||
| from datetime import datetime | |||||
| from typing import Any, Dict, List, Optional | |||||
| from typing import List, Dict, Any, Optional | |||||
| import requests # type: ignore | |||||
| import requests | |||||
| from flask import current_app | |||||
| from langchain.document_loaders.base import BaseLoader | |||||
| from langchain.schema import Document | |||||
| from llama_index.readers.base import BaseReader | |||||
| from llama_index.readers.schema.base import Document | |||||
| from extensions.ext_database import db | |||||
| from models.dataset import Document as DocumentModel | |||||
| from models.source import DataSourceBinding | |||||
| logger = logging.getLogger(__name__) | |||||
| INTEGRATION_TOKEN_NAME = "NOTION_INTEGRATION_TOKEN" | |||||
| BLOCK_CHILD_URL_TMPL = "https://api.notion.com/v1/blocks/{block_id}/children" | BLOCK_CHILD_URL_TMPL = "https://api.notion.com/v1/blocks/{block_id}/children" | ||||
| DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}/query" | DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}/query" | ||||
| SEARCH_URL = "https://api.notion.com/v1/search" | SEARCH_URL = "https://api.notion.com/v1/search" | ||||
| RETRIEVE_PAGE_URL_TMPL = "https://api.notion.com/v1/pages/{page_id}" | RETRIEVE_PAGE_URL_TMPL = "https://api.notion.com/v1/pages/{page_id}" | ||||
| RETRIEVE_DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}" | RETRIEVE_DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}" | ||||
| HEADING_TYPE = ['heading_1', 'heading_2', 'heading_3'] | HEADING_TYPE = ['heading_1', 'heading_2', 'heading_3'] | ||||
| logger = logging.getLogger(__name__) | |||||
| # TODO: Notion DB reader coming soon! | |||||
| class NotionPageReader(BaseReader): | |||||
| """Notion Page reader. | |||||
| Reads a set of Notion pages. | |||||
| Args: | |||||
| integration_token (str): Notion integration token. | |||||
| """ | |||||
| def __init__(self, integration_token: Optional[str] = None) -> None: | |||||
| """Initialize with parameters.""" | |||||
| if integration_token is None: | |||||
| integration_token = os.getenv(INTEGRATION_TOKEN_NAME) | |||||
| class NotionLoader(BaseLoader): | |||||
| def __init__( | |||||
| self, | |||||
| notion_access_token: str, | |||||
| notion_workspace_id: str, | |||||
| notion_obj_id: str, | |||||
| notion_page_type: str, | |||||
| document_model: Optional[DocumentModel] = None | |||||
| ): | |||||
| self._document_model = document_model | |||||
| self._notion_workspace_id = notion_workspace_id | |||||
| self._notion_obj_id = notion_obj_id | |||||
| self._notion_page_type = notion_page_type | |||||
| self._notion_access_token = notion_access_token | |||||
| if not self._notion_access_token: | |||||
| integration_token = current_app.config.get('NOTION_INTEGRATION_TOKEN') | |||||
| if integration_token is None: | if integration_token is None: | ||||
| raise ValueError( | raise ValueError( | ||||
| "Must specify `integration_token` or set environment " | "Must specify `integration_token` or set environment " | ||||
| "variable `NOTION_INTEGRATION_TOKEN`." | "variable `NOTION_INTEGRATION_TOKEN`." | ||||
| ) | ) | ||||
| self.token = integration_token | |||||
| self.headers = { | |||||
| "Authorization": "Bearer " + self.token, | |||||
| "Content-Type": "application/json", | |||||
| "Notion-Version": "2022-06-28", | |||||
| } | |||||
| def _read_block(self, block_id: str, num_tabs: int = 0) -> str: | |||||
| """Read a block.""" | |||||
| done = False | |||||
| self._notion_access_token = integration_token | |||||
| @classmethod | |||||
| def from_document(cls, document_model: DocumentModel): | |||||
| data_source_info = document_model.data_source_info_dict | |||||
| if not data_source_info or 'notion_page_id' not in data_source_info \ | |||||
| or 'notion_workspace_id' not in data_source_info: | |||||
| raise ValueError("no notion page found") | |||||
| notion_workspace_id = data_source_info['notion_workspace_id'] | |||||
| notion_obj_id = data_source_info['notion_page_id'] | |||||
| notion_page_type = data_source_info['type'] | |||||
| notion_access_token = cls._get_access_token(document_model.tenant_id, notion_workspace_id) | |||||
| return cls( | |||||
| notion_access_token=notion_access_token, | |||||
| notion_workspace_id=notion_workspace_id, | |||||
| notion_obj_id=notion_obj_id, | |||||
| notion_page_type=notion_page_type, | |||||
| document_model=document_model | |||||
| ) | |||||
| def load(self) -> List[Document]: | |||||
| self.update_last_edited_time( | |||||
| self._document_model | |||||
| ) | |||||
| text_docs = self._load_data_as_documents(self._notion_obj_id, self._notion_page_type) | |||||
| return text_docs | |||||
| def _load_data_as_documents( | |||||
| self, notion_obj_id: str, notion_page_type: str | |||||
| ) -> List[Document]: | |||||
| docs = [] | |||||
| if notion_page_type == 'database': | |||||
| # get all the pages in the database | |||||
| page_text = self._get_notion_database_data(notion_obj_id) | |||||
| docs.append(Document(page_content=page_text)) | |||||
| elif notion_page_type == 'page': | |||||
| page_text_list = self._get_notion_block_data(notion_obj_id) | |||||
| for page_text in page_text_list: | |||||
| docs.append(Document(page_content=page_text)) | |||||
| else: | |||||
| raise ValueError("notion page type not supported") | |||||
| return docs | |||||
| def _get_notion_database_data( | |||||
| self, database_id: str, query_dict: Dict[str, Any] = {} | |||||
| ) -> str: | |||||
| """Get all the pages from a Notion database.""" | |||||
| res = requests.post( | |||||
| DATABASE_URL_TMPL.format(database_id=database_id), | |||||
| headers={ | |||||
| "Authorization": "Bearer " + self._notion_access_token, | |||||
| "Content-Type": "application/json", | |||||
| "Notion-Version": "2022-06-28", | |||||
| }, | |||||
| json=query_dict, | |||||
| ) | |||||
| data = res.json() | |||||
| database_content_list = [] | |||||
| if 'results' not in data or data["results"] is None: | |||||
| return "" | |||||
| for result in data["results"]: | |||||
| properties = result['properties'] | |||||
| data = {} | |||||
| for property_name, property_value in properties.items(): | |||||
| type = property_value['type'] | |||||
| if type == 'multi_select': | |||||
| value = [] | |||||
| multi_select_list = property_value[type] | |||||
| for multi_select in multi_select_list: | |||||
| value.append(multi_select['name']) | |||||
| elif type == 'rich_text' or type == 'title': | |||||
| if len(property_value[type]) > 0: | |||||
| value = property_value[type][0]['plain_text'] | |||||
| else: | |||||
| value = '' | |||||
| elif type == 'select' or type == 'status': | |||||
| if property_value[type]: | |||||
| value = property_value[type]['name'] | |||||
| else: | |||||
| value = '' | |||||
| else: | |||||
| value = property_value[type] | |||||
| data[property_name] = value | |||||
| database_content_list.append(json.dumps(data, ensure_ascii=False)) | |||||
| return "\n\n".join(database_content_list) | |||||
| def _get_notion_block_data(self, page_id: str) -> List[str]: | |||||
| result_lines_arr = [] | result_lines_arr = [] | ||||
| cur_block_id = block_id | |||||
| while not done: | |||||
| cur_block_id = page_id | |||||
| while True: | |||||
| block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id) | block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id) | ||||
| query_dict: Dict[str, Any] = {} | query_dict: Dict[str, Any] = {} | ||||
| res = requests.request( | res = requests.request( | ||||
| "GET", block_url, headers=self.headers, json=query_dict | |||||
| "GET", | |||||
| block_url, | |||||
| headers={ | |||||
| "Authorization": "Bearer " + self._notion_access_token, | |||||
| "Content-Type": "application/json", | |||||
| "Notion-Version": "2022-06-28", | |||||
| }, | |||||
| json=query_dict | |||||
| ) | ) | ||||
| data = res.json() | data = res.json() | ||||
| if 'results' not in data or data["results"] is None: | |||||
| done = True | |||||
| break | |||||
| # current block's heading | |||||
| heading = '' | heading = '' | ||||
| for result in data["results"]: | for result in data["results"]: | ||||
| result_type = result["type"] | result_type = result["type"] | ||||
| if result_type == 'table': | if result_type == 'table': | ||||
| result_block_id = result["id"] | result_block_id = result["id"] | ||||
| text = self._read_table_rows(result_block_id) | text = self._read_table_rows(result_block_id) | ||||
| text += "\n\n" | |||||
| result_lines_arr.append(text) | result_lines_arr.append(text) | ||||
| else: | else: | ||||
| if "rich_text" in result_obj: | if "rich_text" in result_obj: | ||||
| # skip if doesn't have text object | # skip if doesn't have text object | ||||
| if "text" in rich_text: | if "text" in rich_text: | ||||
| text = rich_text["text"]["content"] | text = rich_text["text"]["content"] | ||||
| prefix = "\t" * num_tabs | |||||
| cur_result_text_arr.append(prefix + text) | |||||
| cur_result_text_arr.append(text) | |||||
| if result_type in HEADING_TYPE: | if result_type in HEADING_TYPE: | ||||
| heading = text | heading = text | ||||
| result_block_id = result["id"] | result_block_id = result["id"] | ||||
| has_children = result["has_children"] | has_children = result["has_children"] | ||||
| block_type = result["type"] | block_type = result["type"] | ||||
| if has_children and block_type != 'child_page': | if has_children and block_type != 'child_page': | ||||
| children_text = self._read_block( | children_text = self._read_block( | ||||
| result_block_id, num_tabs=num_tabs + 1 | |||||
| result_block_id, num_tabs=1 | |||||
| ) | ) | ||||
| cur_result_text_arr.append(children_text) | cur_result_text_arr.append(children_text) | ||||
| cur_result_text = "\n".join(cur_result_text_arr) | cur_result_text = "\n".join(cur_result_text_arr) | ||||
| cur_result_text += "\n\n" | |||||
| if result_type in HEADING_TYPE: | if result_type in HEADING_TYPE: | ||||
| result_lines_arr.append(cur_result_text) | result_lines_arr.append(cur_result_text) | ||||
| else: | else: | ||||
| result_lines_arr.append(f'{heading}\n{cur_result_text}') | result_lines_arr.append(f'{heading}\n{cur_result_text}') | ||||
| if data["next_cursor"] is None: | if data["next_cursor"] is None: | ||||
| done = True | |||||
| break | |||||
| else: | |||||
| cur_block_id = data["next_cursor"] | |||||
| result_lines = "\n".join(result_lines_arr) | |||||
| return result_lines | |||||
| def _read_table_rows(self, block_id: str) -> str: | |||||
| """Read table rows.""" | |||||
| done = False | |||||
| result_lines_arr = [] | |||||
| cur_block_id = block_id | |||||
| while not done: | |||||
| block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id) | |||||
| query_dict: Dict[str, Any] = {} | |||||
| res = requests.request( | |||||
| "GET", block_url, headers=self.headers, json=query_dict | |||||
| ) | |||||
| data = res.json() | |||||
| # get table headers text | |||||
| table_header_cell_texts = [] | |||||
| tabel_header_cells = data["results"][0]['table_row']['cells'] | |||||
| for tabel_header_cell in tabel_header_cells: | |||||
| if tabel_header_cell: | |||||
| for table_header_cell_text in tabel_header_cell: | |||||
| text = table_header_cell_text["text"]["content"] | |||||
| table_header_cell_texts.append(text) | |||||
| # get table columns text and format | |||||
| results = data["results"] | |||||
| for i in range(len(results)-1): | |||||
| column_texts = [] | |||||
| tabel_column_cells = data["results"][i+1]['table_row']['cells'] | |||||
| for j in range(len(tabel_column_cells)): | |||||
| if tabel_column_cells[j]: | |||||
| for table_column_cell_text in tabel_column_cells[j]: | |||||
| column_text = table_column_cell_text["text"]["content"] | |||||
| column_texts.append(f'{table_header_cell_texts[j]}:{column_text}') | |||||
| cur_result_text = "\n".join(column_texts) | |||||
| result_lines_arr.append(cur_result_text) | |||||
| if data["next_cursor"] is None: | |||||
| done = True | |||||
| break | break | ||||
| else: | else: | ||||
| cur_block_id = data["next_cursor"] | cur_block_id = data["next_cursor"] | ||||
| return result_lines_arr | |||||
| result_lines = "\n".join(result_lines_arr) | |||||
| return result_lines | |||||
| def _read_parent_blocks(self, block_id: str, num_tabs: int = 0) -> List[str]: | |||||
| def _read_block(self, block_id: str, num_tabs: int = 0) -> str: | |||||
| """Read a block.""" | """Read a block.""" | ||||
| done = False | |||||
| result_lines_arr = [] | result_lines_arr = [] | ||||
| cur_block_id = block_id | cur_block_id = block_id | ||||
| while not done: | |||||
| while True: | |||||
| block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id) | block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id) | ||||
| query_dict: Dict[str, Any] = {} | query_dict: Dict[str, Any] = {} | ||||
| res = requests.request( | res = requests.request( | ||||
| "GET", block_url, headers=self.headers, json=query_dict | |||||
| "GET", | |||||
| block_url, | |||||
| headers={ | |||||
| "Authorization": "Bearer " + self._notion_access_token, | |||||
| "Content-Type": "application/json", | |||||
| "Notion-Version": "2022-06-28", | |||||
| }, | |||||
| json=query_dict | |||||
| ) | ) | ||||
| data = res.json() | data = res.json() | ||||
| # current block's heading | |||||
| if 'results' not in data or data["results"] is None: | |||||
| break | |||||
| heading = '' | heading = '' | ||||
| for result in data["results"]: | for result in data["results"]: | ||||
| result_type = result["type"] | result_type = result["type"] | ||||
| if result_type == 'table': | if result_type == 'table': | ||||
| result_block_id = result["id"] | result_block_id = result["id"] | ||||
| text = self._read_table_rows(result_block_id) | text = self._read_table_rows(result_block_id) | ||||
| text += "\n\n" | |||||
| result_lines_arr.append(text) | result_lines_arr.append(text) | ||||
| else: | else: | ||||
| if "rich_text" in result_obj: | if "rich_text" in result_obj: | ||||
| # skip if doesn't have text object | # skip if doesn't have text object | ||||
| if "text" in rich_text: | if "text" in rich_text: | ||||
| text = rich_text["text"]["content"] | text = rich_text["text"]["content"] | ||||
| cur_result_text_arr.append(text) | |||||
| prefix = "\t" * num_tabs | |||||
| cur_result_text_arr.append(prefix + text) | |||||
| if result_type in HEADING_TYPE: | if result_type in HEADING_TYPE: | ||||
| heading = text | heading = text | ||||
| result_block_id = result["id"] | result_block_id = result["id"] | ||||
| has_children = result["has_children"] | has_children = result["has_children"] | ||||
| block_type = result["type"] | block_type = result["type"] | ||||
| cur_result_text_arr.append(children_text) | cur_result_text_arr.append(children_text) | ||||
| cur_result_text = "\n".join(cur_result_text_arr) | cur_result_text = "\n".join(cur_result_text_arr) | ||||
| cur_result_text += "\n\n" | |||||
| if result_type in HEADING_TYPE: | if result_type in HEADING_TYPE: | ||||
| result_lines_arr.append(cur_result_text) | result_lines_arr.append(cur_result_text) | ||||
| else: | else: | ||||
| result_lines_arr.append(f'{heading}\n{cur_result_text}') | result_lines_arr.append(f'{heading}\n{cur_result_text}') | ||||
| if data["next_cursor"] is None: | if data["next_cursor"] is None: | ||||
| done = True | |||||
| break | break | ||||
| else: | else: | ||||
| cur_block_id = data["next_cursor"] | cur_block_id = data["next_cursor"] | ||||
| return result_lines_arr | |||||
| def read_page(self, page_id: str) -> str: | |||||
| """Read a page.""" | |||||
| return self._read_block(page_id) | |||||
| def read_page_as_documents(self, page_id: str) -> List[str]: | |||||
| """Read a page as documents.""" | |||||
| return self._read_parent_blocks(page_id) | |||||
| def query_database_data( | |||||
| self, database_id: str, query_dict: Dict[str, Any] = {} | |||||
| ) -> str: | |||||
| """Get all the pages from a Notion database.""" | |||||
| res = requests.post\ | |||||
| ( | |||||
| DATABASE_URL_TMPL.format(database_id=database_id), | |||||
| headers=self.headers, | |||||
| json=query_dict, | |||||
| ) | |||||
| data = res.json() | |||||
| database_content_list = [] | |||||
| if 'results' not in data or data["results"] is None: | |||||
| return "" | |||||
| for result in data["results"]: | |||||
| properties = result['properties'] | |||||
| data = {} | |||||
| for property_name, property_value in properties.items(): | |||||
| type = property_value['type'] | |||||
| if type == 'multi_select': | |||||
| value = [] | |||||
| multi_select_list = property_value[type] | |||||
| for multi_select in multi_select_list: | |||||
| value.append(multi_select['name']) | |||||
| elif type == 'rich_text' or type == 'title': | |||||
| if len(property_value[type]) > 0: | |||||
| value = property_value[type][0]['plain_text'] | |||||
| else: | |||||
| value = '' | |||||
| elif type == 'select' or type == 'status': | |||||
| if property_value[type]: | |||||
| value = property_value[type]['name'] | |||||
| else: | |||||
| value = '' | |||||
| else: | |||||
| value = property_value[type] | |||||
| data[property_name] = value | |||||
| database_content_list.append(json.dumps(data, ensure_ascii=False)) | |||||
| return "\n\n".join(database_content_list) | |||||
| def query_database( | |||||
| self, database_id: str, query_dict: Dict[str, Any] = {} | |||||
| ) -> List[str]: | |||||
| """Get all the pages from a Notion database.""" | |||||
| res = requests.post\ | |||||
| ( | |||||
| DATABASE_URL_TMPL.format(database_id=database_id), | |||||
| headers=self.headers, | |||||
| json=query_dict, | |||||
| ) | |||||
| data = res.json() | |||||
| page_ids = [] | |||||
| for result in data["results"]: | |||||
| page_id = result["id"] | |||||
| page_ids.append(page_id) | |||||
| return page_ids | |||||
| result_lines = "\n".join(result_lines_arr) | |||||
| return result_lines | |||||
| def search(self, query: str) -> List[str]: | |||||
| """Search Notion page given a text query.""" | |||||
| def _read_table_rows(self, block_id: str) -> str: | |||||
| """Read table rows.""" | |||||
| done = False | done = False | ||||
| next_cursor: Optional[str] = None | |||||
| page_ids = [] | |||||
| result_lines_arr = [] | |||||
| cur_block_id = block_id | |||||
| while not done: | while not done: | ||||
| query_dict = { | |||||
| "query": query, | |||||
| } | |||||
| if next_cursor is not None: | |||||
| query_dict["start_cursor"] = next_cursor | |||||
| res = requests.post(SEARCH_URL, headers=self.headers, json=query_dict) | |||||
| block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id) | |||||
| query_dict: Dict[str, Any] = {} | |||||
| res = requests.request( | |||||
| "GET", | |||||
| block_url, | |||||
| headers={ | |||||
| "Authorization": "Bearer " + self._notion_access_token, | |||||
| "Content-Type": "application/json", | |||||
| "Notion-Version": "2022-06-28", | |||||
| }, | |||||
| json=query_dict | |||||
| ) | |||||
| data = res.json() | data = res.json() | ||||
| for result in data["results"]: | |||||
| page_id = result["id"] | |||||
| page_ids.append(page_id) | |||||
| # get table headers text | |||||
| table_header_cell_texts = [] | |||||
| tabel_header_cells = data["results"][0]['table_row']['cells'] | |||||
| for tabel_header_cell in tabel_header_cells: | |||||
| if tabel_header_cell: | |||||
| for table_header_cell_text in tabel_header_cell: | |||||
| text = table_header_cell_text["text"]["content"] | |||||
| table_header_cell_texts.append(text) | |||||
| # get table columns text and format | |||||
| results = data["results"] | |||||
| for i in range(len(results) - 1): | |||||
| column_texts = [] | |||||
| tabel_column_cells = data["results"][i + 1]['table_row']['cells'] | |||||
| for j in range(len(tabel_column_cells)): | |||||
| if tabel_column_cells[j]: | |||||
| for table_column_cell_text in tabel_column_cells[j]: | |||||
| column_text = table_column_cell_text["text"]["content"] | |||||
| column_texts.append(f'{table_header_cell_texts[j]}:{column_text}') | |||||
| cur_result_text = "\n".join(column_texts) | |||||
| result_lines_arr.append(cur_result_text) | |||||
| if data["next_cursor"] is None: | if data["next_cursor"] is None: | ||||
| done = True | done = True | ||||
| break | break | ||||
| else: | else: | ||||
| next_cursor = data["next_cursor"] | |||||
| return page_ids | |||||
| cur_block_id = data["next_cursor"] | |||||
| def load_data( | |||||
| self, page_ids: List[str] = [], database_id: Optional[str] = None | |||||
| ) -> List[Document]: | |||||
| """Load data from the input directory. | |||||
| result_lines = "\n".join(result_lines_arr) | |||||
| return result_lines | |||||
| Args: | |||||
| page_ids (List[str]): List of page ids to load. | |||||
| def update_last_edited_time(self, document_model: DocumentModel): | |||||
| if not document_model: | |||||
| return | |||||
| Returns: | |||||
| List[Document]: List of documents. | |||||
| last_edited_time = self.get_notion_last_edited_time() | |||||
| data_source_info = document_model.data_source_info_dict | |||||
| data_source_info['last_edited_time'] = last_edited_time | |||||
| update_params = { | |||||
| DocumentModel.data_source_info: json.dumps(data_source_info) | |||||
| } | |||||
| """ | |||||
| if not page_ids and not database_id: | |||||
| raise ValueError("Must specify either `page_ids` or `database_id`.") | |||||
| docs = [] | |||||
| if database_id is not None: | |||||
| # get all the pages in the database | |||||
| page_ids = self.query_database(database_id) | |||||
| for page_id in page_ids: | |||||
| page_text = self.read_page(page_id) | |||||
| docs.append(Document(page_text)) | |||||
| else: | |||||
| for page_id in page_ids: | |||||
| page_text = self.read_page(page_id) | |||||
| docs.append(Document(page_text)) | |||||
| DocumentModel.query.filter_by(id=document_model.id).update(update_params) | |||||
| db.session.commit() | |||||
| return docs | |||||
| def load_data_as_documents( | |||||
| self, page_ids: List[str] = [], database_id: Optional[str] = None | |||||
| ) -> List[Document]: | |||||
| if not page_ids and not database_id: | |||||
| raise ValueError("Must specify either `page_ids` or `database_id`.") | |||||
| docs = [] | |||||
| if database_id is not None: | |||||
| # get all the pages in the database | |||||
| page_text = self.query_database_data(database_id) | |||||
| docs.append(Document(page_text)) | |||||
| def get_notion_last_edited_time(self) -> str: | |||||
| obj_id = self._notion_obj_id | |||||
| page_type = self._notion_page_type | |||||
| if page_type == 'database': | |||||
| retrieve_page_url = RETRIEVE_DATABASE_URL_TMPL.format(database_id=obj_id) | |||||
| else: | else: | ||||
| for page_id in page_ids: | |||||
| page_text_list = self.read_page_as_documents(page_id) | |||||
| for page_text in page_text_list: | |||||
| docs.append(Document(page_text)) | |||||
| retrieve_page_url = RETRIEVE_PAGE_URL_TMPL.format(page_id=obj_id) | |||||
| return docs | |||||
| def get_page_last_edited_time(self, page_id: str) -> str: | |||||
| retrieve_page_url = RETRIEVE_PAGE_URL_TMPL.format(page_id=page_id) | |||||
| query_dict: Dict[str, Any] = {} | query_dict: Dict[str, Any] = {} | ||||
| res = requests.request( | res = requests.request( | ||||
| "GET", retrieve_page_url, headers=self.headers, json=query_dict | |||||
| "GET", | |||||
| retrieve_page_url, | |||||
| headers={ | |||||
| "Authorization": "Bearer " + self._notion_access_token, | |||||
| "Content-Type": "application/json", | |||||
| "Notion-Version": "2022-06-28", | |||||
| }, | |||||
| json=query_dict | |||||
| ) | ) | ||||
| data = res.json() | |||||
| return data["last_edited_time"] | |||||
| def get_database_last_edited_time(self, database_id: str) -> str: | |||||
| retrieve_page_url = RETRIEVE_DATABASE_URL_TMPL.format(database_id=database_id) | |||||
| query_dict: Dict[str, Any] = {} | |||||
| res = requests.request( | |||||
| "GET", retrieve_page_url, headers=self.headers, json=query_dict | |||||
| ) | |||||
| data = res.json() | data = res.json() | ||||
| return data["last_edited_time"] | return data["last_edited_time"] | ||||
| @classmethod | |||||
| def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str: | |||||
| data_source_binding = DataSourceBinding.query.filter( | |||||
| db.and_( | |||||
| DataSourceBinding.tenant_id == tenant_id, | |||||
| DataSourceBinding.provider == 'notion', | |||||
| DataSourceBinding.disabled == False, | |||||
| DataSourceBinding.source_info['workspace_id'] == f'"{notion_workspace_id}"' | |||||
| ) | |||||
| ).first() | |||||
| if not data_source_binding: | |||||
| raise Exception(f'No notion data source binding found for tenant {tenant_id} ' | |||||
| f'and notion workspace {notion_workspace_id}') | |||||
| if __name__ == "__main__": | |||||
| reader = NotionPageReader() | |||||
| logger.info(reader.search("What I")) | |||||
| return data_source_binding.access_token |
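A hedged sketch of how the reworked NotionLoader is constructed; the ids and token are placeholders, and from_document resolves the access token through DataSourceBinding as shown above. The module path is assumed:

from core.data_loader.loader.notion import NotionLoader  # module path assumed

loader = NotionLoader(
    notion_access_token='secret_xxx',        # placeholder
    notion_workspace_id='workspace-id',      # placeholder
    notion_obj_id='page-or-database-id',     # placeholder
    notion_page_type='page',                 # 'page' or 'database'
)
docs = loader.load()

# Or, when a dataset Document row already carries data_source_info:
# loader = NotionLoader.from_document(document_model)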
| import logging | |||||
| from typing import List, Optional | |||||
| from langchain.document_loaders import PyPDFium2Loader | |||||
| from langchain.document_loaders.base import BaseLoader | |||||
| from langchain.schema import Document | |||||
| from extensions.ext_storage import storage | |||||
| from models.model import UploadFile | |||||
| logger = logging.getLogger(__name__) | |||||
| class PdfLoader(BaseLoader): | |||||
| """Load pdf files. | |||||
| Args: | |||||
| file_path: Path to the file to load. | |||||
| """ | |||||
| def __init__( | |||||
| self, | |||||
| file_path: str, | |||||
| upload_file: Optional[UploadFile] = None | |||||
| ): | |||||
| """Initialize with file path.""" | |||||
| self._file_path = file_path | |||||
| self._upload_file = upload_file | |||||
| def load(self) -> List[Document]: | |||||
| plaintext_file_key = '' | |||||
| plaintext_file_exists = False | |||||
| if self._upload_file: | |||||
| if self._upload_file.hash: | |||||
| plaintext_file_key = 'upload_files/' + self._upload_file.tenant_id + '/' \ | |||||
| + self._upload_file.hash + '.0625.plaintext' | |||||
| try: | |||||
| text = storage.load(plaintext_file_key).decode('utf-8') | |||||
| plaintext_file_exists = True | |||||
| return [Document(page_content=text)] | |||||
| except FileNotFoundError: | |||||
| pass | |||||
| documents = PyPDFium2Loader(file_path=self._file_path).load() | |||||
| text_list = [] | |||||
| for document in documents: | |||||
| text_list.append(document.page_content) | |||||
| text = "\n\n".join(text_list) | |||||
| # save plaintext file for caching | |||||
| if not plaintext_file_exists and plaintext_file_key: | |||||
| storage.save(plaintext_file_key, text.encode('utf-8')) | |||||
| return documents | |||||
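A short sketch of the PDF loader; without an UploadFile it just parses the file, while with one it also caches the extracted plaintext in storage under a hash-derived key. The path is hypothetical:

from core.data_loader.loader.pdf import PdfLoader  # path taken from the FileExtractor imports above

docs = PdfLoader('/tmp/example.pdf').load()  # hypothetical path; no plaintext caching without upload_file
print(len(docs), 'page documents')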
| from typing import Any, Dict, Optional, Sequence | from typing import Any, Dict, Optional, Sequence | ||||
| import tiktoken | |||||
| from llama_index.data_structs import Node | |||||
| from llama_index.docstore.types import BaseDocumentStore | |||||
| from llama_index.docstore.utils import json_to_doc | |||||
| from llama_index.schema import BaseDocument | |||||
| from langchain.schema import Document | |||||
| from sqlalchemy import func | from sqlalchemy import func | ||||
| from core.llm.token_calculator import TokenCalculator | from core.llm.token_calculator import TokenCalculator | ||||
| from models.dataset import Dataset, DocumentSegment | from models.dataset import Dataset, DocumentSegment | ||||
| class DatesetDocumentStore(BaseDocumentStore): | |||||
| class DatesetDocumentStore: | |||||
| def __init__( | def __init__( | ||||
| self, | self, | ||||
| dataset: Dataset, | dataset: Dataset, | ||||
| return self._embedding_model_name | return self._embedding_model_name | ||||
| @property | @property | ||||
| def docs(self) -> Dict[str, BaseDocument]: | |||||
| def docs(self) -> Dict[str, Document]: | |||||
| document_segments = db.session.query(DocumentSegment).filter( | document_segments = db.session.query(DocumentSegment).filter( | ||||
| DocumentSegment.dataset_id == self._dataset.id | DocumentSegment.dataset_id == self._dataset.id | ||||
| ).all() | ).all() | ||||
| output = {} | output = {} | ||||
| for document_segment in document_segments: | for document_segment in document_segments: | ||||
| doc_id = document_segment.index_node_id | doc_id = document_segment.index_node_id | ||||
| result = self.segment_to_dict(document_segment) | |||||
| output[doc_id] = json_to_doc(result) | |||||
| output[doc_id] = Document( | |||||
| page_content=document_segment.content, | |||||
| metadata={ | |||||
| "doc_id": document_segment.index_node_id, | |||||
| "doc_hash": document_segment.index_node_hash, | |||||
| "document_id": document_segment.document_id, | |||||
| "dataset_id": document_segment.dataset_id, | |||||
| } | |||||
| ) | |||||
| return output | return output | ||||
| def add_documents( | def add_documents( | ||||
| self, docs: Sequence[BaseDocument], allow_update: bool = True | |||||
| self, docs: Sequence[Document], allow_update: bool = True | |||||
| ) -> None: | ) -> None: | ||||
| max_position = db.session.query(func.max(DocumentSegment.position)).filter( | max_position = db.session.query(func.max(DocumentSegment.position)).filter( | ||||
| DocumentSegment.document == self._document_id | DocumentSegment.document == self._document_id | ||||
| max_position = 0 | max_position = 0 | ||||
| for doc in docs: | for doc in docs: | ||||
| if doc.is_doc_id_none: | |||||
| raise ValueError("doc_id not set") | |||||
| if not isinstance(doc, Document): | |||||
| raise ValueError("doc must be a Document") | |||||
| if not isinstance(doc, Node): | |||||
| raise ValueError("doc must be a Node") | |||||
| segment_document = self.get_document(doc_id=doc.get_doc_id(), raise_error=False) | |||||
| segment_document = self.get_document(doc_id=doc.metadata['doc_id'], raise_error=False) | |||||
| # NOTE: doc could already exist in the store, but we overwrite it | # NOTE: doc could already exist in the store, but we overwrite it | ||||
| if not allow_update and segment_document: | if not allow_update and segment_document: | ||||
| raise ValueError( | raise ValueError( | ||||
| f"doc_id {doc.get_doc_id()} already exists. " | |||||
| f"doc_id {doc.metadata['doc_id']} already exists. " | |||||
| "Set allow_update to True to overwrite." | "Set allow_update to True to overwrite." | ||||
| ) | ) | ||||
| # calc embedding use tokens | # calc embedding use tokens | ||||
| tokens = TokenCalculator.get_num_tokens(self._embedding_model_name, doc.get_text()) | |||||
| tokens = TokenCalculator.get_num_tokens(self._embedding_model_name, doc.page_content) | |||||
| if not segment_document: | if not segment_document: | ||||
| max_position += 1 | max_position += 1 | ||||
| tenant_id=self._dataset.tenant_id, | tenant_id=self._dataset.tenant_id, | ||||
| dataset_id=self._dataset.id, | dataset_id=self._dataset.id, | ||||
| document_id=self._document_id, | document_id=self._document_id, | ||||
| index_node_id=doc.get_doc_id(), | |||||
| index_node_hash=doc.get_doc_hash(), | |||||
| index_node_id=doc.metadata['doc_id'], | |||||
| index_node_hash=doc.metadata['doc_hash'], | |||||
| position=max_position, | position=max_position, | ||||
| content=doc.get_text(), | |||||
| word_count=len(doc.get_text()), | |||||
| content=doc.page_content, | |||||
| word_count=len(doc.page_content), | |||||
| tokens=tokens, | tokens=tokens, | ||||
| created_by=self._user_id, | created_by=self._user_id, | ||||
| ) | ) | ||||
| db.session.add(segment_document) | db.session.add(segment_document) | ||||
| else: | else: | ||||
| segment_document.content = doc.get_text() | |||||
| segment_document.index_node_hash = doc.get_doc_hash() | |||||
| segment_document.word_count = len(doc.get_text()) | |||||
| segment_document.content = doc.page_content | |||||
| segment_document.index_node_hash = doc.metadata['doc_hash'] | |||||
| segment_document.word_count = len(doc.page_content) | |||||
| segment_document.tokens = tokens | segment_document.tokens = tokens | ||||
| db.session.commit() | db.session.commit() | ||||
| def get_document( | def get_document( | ||||
| self, doc_id: str, raise_error: bool = True | self, doc_id: str, raise_error: bool = True | ||||
| ) -> Optional[BaseDocument]: | |||||
| ) -> Optional[Document]: | |||||
| document_segment = self.get_document_segment(doc_id) | document_segment = self.get_document_segment(doc_id) | ||||
| if document_segment is None: | if document_segment is None: | ||||
| else: | else: | ||||
| return None | return None | ||||
| result = self.segment_to_dict(document_segment) | |||||
| return json_to_doc(result) | |||||
| return Document( | |||||
| page_content=document_segment.content, | |||||
| metadata={ | |||||
| "doc_id": document_segment.index_node_id, | |||||
| "doc_hash": document_segment.index_node_hash, | |||||
| "document_id": document_segment.document_id, | |||||
| "dataset_id": document_segment.dataset_id, | |||||
| } | |||||
| ) | |||||
| def delete_document(self, doc_id: str, raise_error: bool = True) -> None: | def delete_document(self, doc_id: str, raise_error: bool = True) -> None: | ||||
| document_segment = self.get_document_segment(doc_id) | document_segment = self.get_document_segment(doc_id) | ||||
| return document_segment.index_node_hash | return document_segment.index_node_hash | ||||
| def update_docstore(self, other: "BaseDocumentStore") -> None: | |||||
| """Update docstore. | |||||
| Args: | |||||
| other (BaseDocumentStore): docstore to update from | |||||
| """ | |||||
| self.add_documents(list(other.docs.values())) | |||||
| def get_document_segment(self, doc_id: str) -> DocumentSegment: | def get_document_segment(self, doc_id: str) -> DocumentSegment: | ||||
| document_segment = db.session.query(DocumentSegment).filter( | document_segment = db.session.query(DocumentSegment).filter( | ||||
| DocumentSegment.dataset_id == self._dataset.id, | DocumentSegment.dataset_id == self._dataset.id, | ||||
| ).first() | ).first() | ||||
| return document_segment | return document_segment | ||||
| def segment_to_dict(self, segment: DocumentSegment) -> Dict[str, Any]: | |||||
| return { | |||||
| "doc_id": segment.index_node_id, | |||||
| "doc_hash": segment.index_node_hash, | |||||
| "text": segment.content, | |||||
| "__type__": Node.get_type() | |||||
| } |
| from typing import Any, Dict, Optional, Sequence | |||||
| from llama_index.docstore.types import BaseDocumentStore | |||||
| from llama_index.schema import BaseDocument | |||||
| class EmptyDocumentStore(BaseDocumentStore): | |||||
| @classmethod | |||||
| def from_dict(cls, config_dict: Dict[str, Any]) -> "EmptyDocumentStore": | |||||
| return cls() | |||||
| def to_dict(self) -> Dict[str, Any]: | |||||
| """Serialize to dict.""" | |||||
| return {} | |||||
| @property | |||||
| def docs(self) -> Dict[str, BaseDocument]: | |||||
| return {} | |||||
| def add_documents( | |||||
| self, docs: Sequence[BaseDocument], allow_update: bool = True | |||||
| ) -> None: | |||||
| pass | |||||
| def document_exists(self, doc_id: str) -> bool: | |||||
| """Check if document exists.""" | |||||
| return False | |||||
| def get_document( | |||||
| self, doc_id: str, raise_error: bool = True | |||||
| ) -> Optional[BaseDocument]: | |||||
| return None | |||||
| def delete_document(self, doc_id: str, raise_error: bool = True) -> None: | |||||
| pass | |||||
| def set_document_hash(self, doc_id: str, doc_hash: str) -> None: | |||||
| """Set the hash for a given doc_id.""" | |||||
| pass | |||||
| def get_document_hash(self, doc_id: str) -> Optional[str]: | |||||
| """Get the stored hash for a document, if it exists.""" | |||||
| return None | |||||
| def update_docstore(self, other: "BaseDocumentStore") -> None: | |||||
| """Update docstore. | |||||
| Args: | |||||
| other (BaseDocumentStore): docstore to update from | |||||
| """ | |||||
| self.add_documents(list(other.docs.values())) |
| import logging | |||||
| from typing import List | |||||
| from langchain.embeddings.base import Embeddings | |||||
| from sqlalchemy.exc import IntegrityError | |||||
| from extensions.ext_database import db | |||||
| from libs import helper | |||||
| from models.dataset import Embedding | |||||
| class CacheEmbedding(Embeddings): | |||||
| def __init__(self, embeddings: Embeddings): | |||||
| self._embeddings = embeddings | |||||
| def embed_documents(self, texts: List[str]) -> List[List[float]]: | |||||
| """Embed search docs.""" | |||||
| # use doc embedding cache or store if not exists | |||||
| text_embeddings = [] | |||||
| embedding_queue_texts = [] | |||||
| for text in texts: | |||||
| hash = helper.generate_text_hash(text) | |||||
| embedding = db.session.query(Embedding).filter_by(hash=hash).first() | |||||
| if embedding: | |||||
| text_embeddings.append(embedding.get_embedding()) | |||||
| else: | |||||
| embedding_queue_texts.append(text) | |||||
| embedding_results = self._embeddings.embed_documents(embedding_queue_texts) | |||||
| # enumerate keeps each queued text paired with its own index into embedding_results, | |||||
| # even when a failed commit skips to the next text | |||||
| for i, text in enumerate(embedding_queue_texts): | |||||
| hash = helper.generate_text_hash(text) | |||||
| try: | |||||
| embedding = Embedding(hash=hash) | |||||
| embedding.set_embedding(embedding_results[i]) | |||||
| db.session.add(embedding) | |||||
| db.session.commit() | |||||
| except IntegrityError: | |||||
| db.session.rollback() | |||||
| continue | |||||
| except Exception: | |||||
| logging.exception('Failed to add embedding to db') | |||||
| continue | |||||
| text_embeddings.extend(embedding_results) | |||||
| return text_embeddings | |||||
| def embed_query(self, text: str) -> List[float]: | |||||
| """Embed query text.""" | |||||
| # use doc embedding cache or store if not exists | |||||
| hash = helper.generate_text_hash(text) | |||||
| embedding = db.session.query(Embedding).filter_by(hash=hash).first() | |||||
| if embedding: | |||||
| return embedding.get_embedding() | |||||
| embedding_results = self._embeddings.embed_query(text) | |||||
| try: | |||||
| embedding = Embedding(hash=hash) | |||||
| embedding.set_embedding(embedding_results) | |||||
| db.session.add(embedding) | |||||
| db.session.commit() | |||||
| except IntegrityError: | |||||
| db.session.rollback() | |||||
| except Exception: | |||||
| logging.exception('Failed to add embedding to db') | |||||
| return embedding_results |
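A usage sketch of the cache wrapper; it accepts any langchain Embeddings implementation and stores results in the Embedding table keyed by a hash of the text. The provider, API key, and module path are assumptions:

from langchain.embeddings import OpenAIEmbeddings            # any Embeddings implementation works
from core.embedding.cached_embedding import CacheEmbedding   # module path assumed

embeddings = CacheEmbedding(OpenAIEmbeddings(openai_api_key='sk-...'))  # placeholder key
vectors = embeddings.embed_documents(['hello world', 'good morning'])   # cache miss, then stored
query_vector = embeddings.embed_query('hello world')                    # served from the cache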
| from typing import Optional, Any, List | |||||
| import openai | |||||
| from llama_index.embeddings.base import BaseEmbedding | |||||
| from llama_index.embeddings.openai import OpenAIEmbeddingMode, OpenAIEmbeddingModelType, _QUERY_MODE_MODEL_DICT, \ | |||||
| _TEXT_MODE_MODEL_DICT | |||||
| from tenacity import wait_random_exponential, retry, stop_after_attempt | |||||
| from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async | |||||
| @retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6)) | |||||
| def get_embedding( | |||||
| text: str, | |||||
| engine: Optional[str] = None, | |||||
| api_key: Optional[str] = None, | |||||
| **kwargs | |||||
| ) -> List[float]: | |||||
| """Get embedding. | |||||
| NOTE: Copied from OpenAI's embedding utils: | |||||
| https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py | |||||
| Copied here to avoid importing unnecessary dependencies | |||||
| like matplotlib, plotly, scipy, sklearn. | |||||
| """ | |||||
| text = text.replace("\n", " ") | |||||
| return openai.Embedding.create(input=[text], engine=engine, api_key=api_key, **kwargs)["data"][0]["embedding"] | |||||
| @retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6)) | |||||
| async def aget_embedding(text: str, engine: Optional[str] = None, api_key: Optional[str] = None, **kwargs) -> List[ | |||||
| float]: | |||||
| """Asynchronously get embedding. | |||||
| NOTE: Copied from OpenAI's embedding utils: | |||||
| https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py | |||||
| Copied here to avoid importing unnecessary dependencies | |||||
| like matplotlib, plotly, scipy, sklearn. | |||||
| """ | |||||
| # replace newlines, which can negatively affect performance. | |||||
| text = text.replace("\n", " ") | |||||
| return (await openai.Embedding.acreate(input=[text], engine=engine, api_key=api_key, **kwargs))["data"][0][ | |||||
| "embedding" | |||||
| ] | |||||
| @retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6)) | |||||
| def get_embeddings( | |||||
| list_of_text: List[str], | |||||
| engine: Optional[str] = None, | |||||
| api_key: Optional[str] = None, | |||||
| **kwargs | |||||
| ) -> List[List[float]]: | |||||
| """Get embeddings. | |||||
| NOTE: Copied from OpenAI's embedding utils: | |||||
| https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py | |||||
| Copied here to avoid importing unnecessary dependencies | |||||
| like matplotlib, plotly, scipy, sklearn. | |||||
| """ | |||||
| assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048." | |||||
| # replace newlines, which can negatively affect performance. | |||||
| list_of_text = [text.replace("\n", " ") for text in list_of_text] | |||||
| data = openai.Embedding.create(input=list_of_text, engine=engine, api_key=api_key, **kwargs).data | |||||
| data = sorted(data, key=lambda x: x["index"]) # maintain the same order as input. | |||||
| return [d["embedding"] for d in data] | |||||
| @retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6)) | |||||
| async def aget_embeddings( | |||||
| list_of_text: List[str], engine: Optional[str] = None, api_key: Optional[str] = None, **kwargs | |||||
| ) -> List[List[float]]: | |||||
| """Asynchronously get embeddings. | |||||
| NOTE: Copied from OpenAI's embedding utils: | |||||
| https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py | |||||
| Copied here to avoid importing unnecessary dependencies | |||||
| like matplotlib, plotly, scipy, sklearn. | |||||
| """ | |||||
| assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048." | |||||
| # replace newlines, which can negatively affect performance. | |||||
| list_of_text = [text.replace("\n", " ") for text in list_of_text] | |||||
| data = (await openai.Embedding.acreate(input=list_of_text, engine=engine, api_key=api_key, **kwargs)).data | |||||
| data = sorted(data, key=lambda x: x["index"]) # maintain the same order as input. | |||||
| return [d["embedding"] for d in data] | |||||
| class OpenAIEmbedding(BaseEmbedding): | |||||
| def __init__( | |||||
| self, | |||||
| mode: str = OpenAIEmbeddingMode.TEXT_SEARCH_MODE, | |||||
| model: str = OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002, | |||||
| deployment_name: Optional[str] = None, | |||||
| openai_api_key: Optional[str] = None, | |||||
| **kwargs: Any, | |||||
| ) -> None: | |||||
| """Init params.""" | |||||
| new_kwargs = {} | |||||
| if 'embed_batch_size' in kwargs: | |||||
| new_kwargs['embed_batch_size'] = kwargs['embed_batch_size'] | |||||
| if 'tokenizer' in kwargs: | |||||
| new_kwargs['tokenizer'] = kwargs['tokenizer'] | |||||
| super().__init__(**new_kwargs) | |||||
| self.mode = OpenAIEmbeddingMode(mode) | |||||
| self.model = OpenAIEmbeddingModelType(model) | |||||
| self.deployment_name = deployment_name | |||||
| self.openai_api_key = openai_api_key | |||||
| self.openai_api_type = kwargs.get('openai_api_type') | |||||
| self.openai_api_version = kwargs.get('openai_api_version') | |||||
| self.openai_api_base = kwargs.get('openai_api_base') | |||||
| @handle_llm_exceptions | |||||
| def _get_query_embedding(self, query: str) -> List[float]: | |||||
| """Get query embedding.""" | |||||
| if self.deployment_name is not None: | |||||
| engine = self.deployment_name | |||||
| else: | |||||
| key = (self.mode, self.model) | |||||
| if key not in _QUERY_MODE_MODEL_DICT: | |||||
| raise ValueError(f"Invalid mode, model combination: {key}") | |||||
| engine = _QUERY_MODE_MODEL_DICT[key] | |||||
| return get_embedding(query, engine=engine, api_key=self.openai_api_key, | |||||
| api_type=self.openai_api_type, api_version=self.openai_api_version, | |||||
| api_base=self.openai_api_base) | |||||
| def _get_text_embedding(self, text: str) -> List[float]: | |||||
| """Get text embedding.""" | |||||
| if self.deployment_name is not None: | |||||
| engine = self.deployment_name | |||||
| else: | |||||
| key = (self.mode, self.model) | |||||
| if key not in _TEXT_MODE_MODEL_DICT: | |||||
| raise ValueError(f"Invalid mode, model combination: {key}") | |||||
| engine = _TEXT_MODE_MODEL_DICT[key] | |||||
| return get_embedding(text, engine=engine, api_key=self.openai_api_key, | |||||
| api_type=self.openai_api_type, api_version=self.openai_api_version, | |||||
| api_base=self.openai_api_base) | |||||
| async def _aget_text_embedding(self, text: str) -> List[float]: | |||||
| """Asynchronously get text embedding.""" | |||||
| if self.deployment_name is not None: | |||||
| engine = self.deployment_name | |||||
| else: | |||||
| key = (self.mode, self.model) | |||||
| if key not in _TEXT_MODE_MODEL_DICT: | |||||
| raise ValueError(f"Invalid mode, model combination: {key}") | |||||
| engine = _TEXT_MODE_MODEL_DICT[key] | |||||
| return await aget_embedding(text, engine=engine, api_key=self.openai_api_key, | |||||
| api_type=self.openai_api_type, api_version=self.openai_api_version, | |||||
| api_base=self.openai_api_base) | |||||
| def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]: | |||||
| """Get text embeddings. | |||||
| By default, this is a wrapper around _get_text_embedding. | |||||
| Can be overridden for batch queries. | |||||
| """ | |||||
| if self.openai_api_type and self.openai_api_type == 'azure': | |||||
| embeddings = [] | |||||
| for text in texts: | |||||
| embeddings.append(self._get_text_embedding(text)) | |||||
| return embeddings | |||||
| if self.deployment_name is not None: | |||||
| engine = self.deployment_name | |||||
| else: | |||||
| key = (self.mode, self.model) | |||||
| if key not in _TEXT_MODE_MODEL_DICT: | |||||
| raise ValueError(f"Invalid mode, model combination: {key}") | |||||
| engine = _TEXT_MODE_MODEL_DICT[key] | |||||
| embeddings = get_embeddings(texts, engine=engine, api_key=self.openai_api_key, | |||||
| api_type=self.openai_api_type, api_version=self.openai_api_version, | |||||
| api_base=self.openai_api_base) | |||||
| return embeddings | |||||
| async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]: | |||||
| """Asynchronously get text embeddings.""" | |||||
| if self.openai_api_type and self.openai_api_type == 'azure': | |||||
| embeddings = [] | |||||
| for text in texts: | |||||
| embeddings.append(await self._aget_text_embedding(text)) | |||||
| return embeddings | |||||
| if self.deployment_name is not None: | |||||
| engine = self.deployment_name | |||||
| else: | |||||
| key = (self.mode, self.model) | |||||
| if key not in _TEXT_MODE_MODEL_DICT: | |||||
| raise ValueError(f"Invalid mode, model combination: {key}") | |||||
| engine = _TEXT_MODE_MODEL_DICT[key] | |||||
| embeddings = await aget_embeddings(texts, engine=engine, api_key=self.openai_api_key, | |||||
| api_type=self.openai_api_type, api_version=self.openai_api_version, | |||||
| api_base=self.openai_api_base) | |||||
| return embeddings |
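# A minimal construction sketch for the wrapper above. For plain OpenAI the engine is
# resolved from the (mode, model) pair; for Azure, deployment_name plus the
# openai_api_type/openai_api_version/openai_api_base kwargs select the deployment.
# All credential and endpoint values here are placeholders.
from core.embedding.openai_embedding import OpenAIEmbedding

embed_model = OpenAIEmbedding(openai_api_key='sk-...')
azure_embed_model = OpenAIEmbedding(
    deployment_name='my-ada-002-deployment',
    openai_api_key='...',
    openai_api_type='azure',
    openai_api_version='2023-05-15',
    openai_api_base='https://example.openai.azure.com',
)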
| from __future__ import annotations | |||||
| from abc import abstractmethod, ABC | |||||
| from typing import List, Any | |||||
| from langchain.schema import Document, BaseRetriever | |||||
| from models.dataset import Dataset | |||||
| class BaseIndex(ABC): | |||||
| def __init__(self, dataset: Dataset): | |||||
| self.dataset = dataset | |||||
| @abstractmethod | |||||
| def create(self, texts: list[Document], **kwargs) -> BaseIndex: | |||||
| raise NotImplementedError | |||||
| @abstractmethod | |||||
| def add_texts(self, texts: list[Document], **kwargs): | |||||
| raise NotImplementedError | |||||
| @abstractmethod | |||||
| def text_exists(self, id: str) -> bool: | |||||
| raise NotImplementedError | |||||
| @abstractmethod | |||||
| def delete_by_ids(self, ids: list[str]) -> None: | |||||
| raise NotImplementedError | |||||
| @abstractmethod | |||||
| def delete_by_document_id(self, document_id: str): | |||||
| raise NotImplementedError | |||||
| @abstractmethod | |||||
| def get_retriever(self, **kwargs: Any) -> BaseRetriever: | |||||
| raise NotImplementedError | |||||
| @abstractmethod | |||||
| def search( | |||||
| self, query: str, | |||||
| **kwargs: Any | |||||
| ) -> List[Document]: | |||||
| raise NotImplementedError | |||||
| def delete(self) -> None: | |||||
| raise NotImplementedError | |||||
| def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]: | |||||
| # iterate over a copy: removing items from the list being iterated would skip elements | |||||
| for text in texts[:]: | |||||
| doc_id = text.metadata['doc_id'] | |||||
| if self.text_exists(doc_id): | |||||
| texts.remove(text) | |||||
| return texts | |||||
| def _get_uuids(self, texts: list[Document]) -> list[str]: | |||||
| return [text.metadata['doc_id'] for text in texts] |
| from flask import current_app | |||||
| from langchain.embeddings import OpenAIEmbeddings | |||||
| from core.embedding.cached_embedding import CacheEmbedding | |||||
| from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig | |||||
| from core.index.vector_index.vector_index import VectorIndex | |||||
| from core.llm.llm_builder import LLMBuilder | |||||
| from models.dataset import Dataset | |||||
| class IndexBuilder: | |||||
| @classmethod | |||||
| def get_index(cls, dataset: Dataset, indexing_technique: str, ignore_high_quality_check: bool = False): | |||||
| if indexing_technique == "high_quality": | |||||
| if not ignore_high_quality_check and dataset.indexing_technique != 'high_quality': | |||||
| return None | |||||
| model_credentials = LLMBuilder.get_model_credentials( | |||||
| tenant_id=dataset.tenant_id, | |||||
| model_provider=LLMBuilder.get_default_provider(dataset.tenant_id), | |||||
| model_name='text-embedding-ada-002' | |||||
| ) | |||||
| embeddings = CacheEmbedding(OpenAIEmbeddings( | |||||
| **model_credentials | |||||
| )) | |||||
| return VectorIndex( | |||||
| dataset=dataset, | |||||
| config=current_app.config, | |||||
| embeddings=embeddings | |||||
| ) | |||||
| elif indexing_technique == "economy": | |||||
| return KeywordTableIndex( | |||||
| dataset=dataset, | |||||
| config=KeywordTableConfig( | |||||
| max_keywords_per_chunk=10 | |||||
| ) | |||||
| ) | |||||
| else: | |||||
| raise ValueError('Unknown indexing technique') |
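# A minimal usage sketch of the factory above, assuming `dataset` is a loaded
# models.dataset.Dataset row and a Flask app context is active. The 'economy' branch
# returns the keyword-table index defined later in this diff; the 'high_quality'
# branch returns the vector index wrapper, or None when the dataset was indexed with
# another technique and ignore_high_quality_check is not set.
index = IndexBuilder.get_index(dataset, 'economy')
if index:
    documents = index.search("example query", search_kwargs={'k': 4})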
| from langchain.callbacks import CallbackManager | |||||
| from llama_index import ServiceContext, PromptHelper, LLMPredictor | |||||
| from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler | |||||
| from core.embedding.openai_embedding import OpenAIEmbedding | |||||
| from core.llm.llm_builder import LLMBuilder | |||||
| class IndexBuilder: | |||||
| @classmethod | |||||
| def get_default_service_context(cls, tenant_id: str) -> ServiceContext: | |||||
| # set number of output tokens | |||||
| num_output = 512 | |||||
| # only for verbose | |||||
| callback_manager = CallbackManager([DifyStdOutCallbackHandler()]) | |||||
| llm = LLMBuilder.to_llm( | |||||
| tenant_id=tenant_id, | |||||
| model_name='text-davinci-003', | |||||
| temperature=0, | |||||
| max_tokens=num_output, | |||||
| callback_manager=callback_manager, | |||||
| ) | |||||
| llm_predictor = LLMPredictor(llm=llm) | |||||
| # These parameters affect how the final synthesized response is segmented: | |||||
| # the number of refinement iterations in the synthesis process depends | |||||
| # on whether the length of the segmented output exceeds max_input_size. | |||||
| prompt_helper = PromptHelper( | |||||
| max_input_size=3500, | |||||
| num_output=num_output, | |||||
| max_chunk_overlap=20 | |||||
| ) | |||||
| provider = LLMBuilder.get_default_provider(tenant_id) | |||||
| model_credentials = LLMBuilder.get_model_credentials( | |||||
| tenant_id=tenant_id, | |||||
| model_provider=provider, | |||||
| model_name='text-embedding-ada-002' | |||||
| ) | |||||
| return ServiceContext.from_defaults( | |||||
| llm_predictor=llm_predictor, | |||||
| prompt_helper=prompt_helper, | |||||
| embed_model=OpenAIEmbedding(**model_credentials), | |||||
| ) | |||||
| @classmethod | |||||
| def get_fake_llm_service_context(cls, tenant_id: str) -> ServiceContext: | |||||
| llm = LLMBuilder.to_llm( | |||||
| tenant_id=tenant_id, | |||||
| model_name='fake' | |||||
| ) | |||||
| return ServiceContext.from_defaults( | |||||
| llm_predictor=LLMPredictor(llm=llm), | |||||
| embed_model=OpenAIEmbedding() | |||||
| ) |
| import re | |||||
| from typing import ( | |||||
| Any, | |||||
| Dict, | |||||
| List, | |||||
| Set, | |||||
| Optional | |||||
| ) | |||||
| import jieba.analyse | |||||
| from core.index.keyword_table.stopwords import STOPWORDS | |||||
| from llama_index.indices.query.base import IS | |||||
| from llama_index import QueryMode | |||||
| from llama_index.indices.base import QueryMap | |||||
| from llama_index.indices.keyword_table.base import BaseGPTKeywordTableIndex | |||||
| from llama_index.indices.keyword_table.query import BaseGPTKeywordTableQuery | |||||
| from llama_index.docstore import BaseDocumentStore | |||||
| from llama_index.indices.postprocessor.node import ( | |||||
| BaseNodePostprocessor, | |||||
| ) | |||||
| from llama_index.indices.response.response_builder import ResponseMode | |||||
| from llama_index.indices.service_context import ServiceContext | |||||
| from llama_index.optimization.optimizer import BaseTokenUsageOptimizer | |||||
| from llama_index.prompts.prompts import ( | |||||
| QuestionAnswerPrompt, | |||||
| RefinePrompt, | |||||
| SimpleInputPrompt, | |||||
| ) | |||||
| from core.index.query.synthesizer import EnhanceResponseSynthesizer | |||||
| def jieba_extract_keywords( | |||||
| text_chunk: str, | |||||
| max_keywords: Optional[int] = None, | |||||
| expand_with_subtokens: bool = True, | |||||
| ) -> Set[str]: | |||||
| """Extract keywords with JIEBA tfidf.""" | |||||
| keywords = jieba.analyse.extract_tags( | |||||
| sentence=text_chunk, | |||||
| topK=max_keywords, | |||||
| ) | |||||
| if expand_with_subtokens: | |||||
| return set(expand_tokens_with_subtokens(keywords)) | |||||
| else: | |||||
| return set(keywords) | |||||
| def expand_tokens_with_subtokens(tokens: Set[str]) -> Set[str]: | |||||
| """Get subtokens from a list of tokens., filtering for stopwords.""" | |||||
| results = set() | |||||
| for token in tokens: | |||||
| results.add(token) | |||||
| sub_tokens = re.findall(r"\w+", token) | |||||
| if len(sub_tokens) > 1: | |||||
| results.update({w for w in sub_tokens if w not in STOPWORDS}) | |||||
| return results | |||||
| class GPTJIEBAKeywordTableIndex(BaseGPTKeywordTableIndex): | |||||
| """GPT JIEBA Keyword Table Index. | |||||
| This index uses a JIEBA keyword extractor to extract keywords from the text. | |||||
| """ | |||||
| def _extract_keywords(self, text: str) -> Set[str]: | |||||
| """Extract keywords from text.""" | |||||
| return jieba_extract_keywords(text, max_keywords=self.max_keywords_per_chunk) | |||||
| @classmethod | |||||
| def get_query_map(cls) -> QueryMap: | |||||
| """Get query map.""" | |||||
| super_map = super().get_query_map() | |||||
| super_map[QueryMode.DEFAULT] = GPTKeywordTableJIEBAQuery | |||||
| return super_map | |||||
| def _delete(self, doc_id: str, **delete_kwargs: Any) -> None: | |||||
| """Delete a document.""" | |||||
| # get set of ids that correspond to node | |||||
| node_idxs_to_delete = {doc_id} | |||||
| # delete node_idxs from keyword to node idxs mapping | |||||
| keywords_to_delete = set() | |||||
| for keyword, node_idxs in self._index_struct.table.items(): | |||||
| if node_idxs_to_delete.intersection(node_idxs): | |||||
| self._index_struct.table[keyword] = node_idxs.difference( | |||||
| node_idxs_to_delete | |||||
| ) | |||||
| if not self._index_struct.table[keyword]: | |||||
| keywords_to_delete.add(keyword) | |||||
| for keyword in keywords_to_delete: | |||||
| del self._index_struct.table[keyword] | |||||
| class GPTKeywordTableJIEBAQuery(BaseGPTKeywordTableQuery): | |||||
| """GPT Keyword Table Index JIEBA Query. | |||||
| Extracts keywords using JIEBA keyword extractor. | |||||
| Set when `mode="jieba"` in `query` method of `GPTKeywordTableIndex`. | |||||
| .. code-block:: python | |||||
| response = index.query("<query_str>", mode="jieba") | |||||
| See BaseGPTKeywordTableQuery for arguments. | |||||
| """ | |||||
| @classmethod | |||||
| def from_args( | |||||
| cls, | |||||
| index_struct: IS, | |||||
| service_context: ServiceContext, | |||||
| docstore: Optional[BaseDocumentStore] = None, | |||||
| node_postprocessors: Optional[List[BaseNodePostprocessor]] = None, | |||||
| verbose: bool = False, | |||||
| # response synthesizer args | |||||
| response_mode: ResponseMode = ResponseMode.DEFAULT, | |||||
| text_qa_template: Optional[QuestionAnswerPrompt] = None, | |||||
| refine_template: Optional[RefinePrompt] = None, | |||||
| simple_template: Optional[SimpleInputPrompt] = None, | |||||
| response_kwargs: Optional[Dict] = None, | |||||
| use_async: bool = False, | |||||
| streaming: bool = False, | |||||
| optimizer: Optional[BaseTokenUsageOptimizer] = None, | |||||
| # class-specific args | |||||
| **kwargs: Any, | |||||
| ) -> "BaseGPTIndexQuery": | |||||
| response_synthesizer = EnhanceResponseSynthesizer.from_args( | |||||
| service_context=service_context, | |||||
| text_qa_template=text_qa_template, | |||||
| refine_template=refine_template, | |||||
| simple_template=simple_template, | |||||
| response_mode=response_mode, | |||||
| response_kwargs=response_kwargs, | |||||
| use_async=use_async, | |||||
| streaming=streaming, | |||||
| optimizer=optimizer, | |||||
| ) | |||||
| return cls( | |||||
| index_struct=index_struct, | |||||
| service_context=service_context, | |||||
| response_synthesizer=response_synthesizer, | |||||
| docstore=docstore, | |||||
| node_postprocessors=node_postprocessors, | |||||
| verbose=verbose, | |||||
| **kwargs, | |||||
| ) | |||||
| def _get_keywords(self, query_str: str) -> List[str]: | |||||
| """Extract keywords.""" | |||||
| return list( | |||||
| jieba_extract_keywords(query_str, max_keywords=self.max_keywords_per_query) | |||||
| ) |
| import json | |||||
| from typing import List, Optional | |||||
| from llama_index import ServiceContext, LLMPredictor, OpenAIEmbedding | |||||
| from llama_index.data_structs import KeywordTable, Node | |||||
| from llama_index.indices.keyword_table.base import BaseGPTKeywordTableIndex | |||||
| from llama_index.indices.registry import load_index_struct_from_dict | |||||
| from core.docstore.dataset_docstore import DatesetDocumentStore | |||||
| from core.docstore.empty_docstore import EmptyDocumentStore | |||||
| from core.index.index_builder import IndexBuilder | |||||
| from core.index.keyword_table.jieba_keyword_table import GPTJIEBAKeywordTableIndex | |||||
| from core.llm.llm_builder import LLMBuilder | |||||
| from extensions.ext_database import db | |||||
| from models.dataset import Dataset, DatasetKeywordTable, DocumentSegment | |||||
| class KeywordTableIndex: | |||||
| def __init__(self, dataset: Dataset): | |||||
| self._dataset = dataset | |||||
| def add_nodes(self, nodes: List[Node]): | |||||
| llm = LLMBuilder.to_llm( | |||||
| tenant_id=self._dataset.tenant_id, | |||||
| model_name='fake' | |||||
| ) | |||||
| service_context = ServiceContext.from_defaults( | |||||
| llm_predictor=LLMPredictor(llm=llm), | |||||
| embed_model=OpenAIEmbedding() | |||||
| ) | |||||
| dataset_keyword_table = self.get_keyword_table() | |||||
| if not dataset_keyword_table or not dataset_keyword_table.keyword_table_dict: | |||||
| index_struct = KeywordTable() | |||||
| else: | |||||
| index_struct_dict = dataset_keyword_table.keyword_table_dict | |||||
| index_struct: KeywordTable = load_index_struct_from_dict(index_struct_dict) | |||||
| # create index | |||||
| index = GPTJIEBAKeywordTableIndex( | |||||
| index_struct=index_struct, | |||||
| docstore=EmptyDocumentStore(), | |||||
| service_context=service_context | |||||
| ) | |||||
| for node in nodes: | |||||
| keywords = index._extract_keywords(node.get_text()) | |||||
| self.update_segment_keywords(node.doc_id, list(keywords)) | |||||
| index._index_struct.add_node(list(keywords), node) | |||||
| index_struct_dict = index.index_struct.to_dict() | |||||
| if not dataset_keyword_table: | |||||
| dataset_keyword_table = DatasetKeywordTable( | |||||
| dataset_id=self._dataset.id, | |||||
| keyword_table=json.dumps(index_struct_dict) | |||||
| ) | |||||
| db.session.add(dataset_keyword_table) | |||||
| else: | |||||
| dataset_keyword_table.keyword_table = json.dumps(index_struct_dict) | |||||
| db.session.commit() | |||||
| def del_nodes(self, node_ids: List[str]): | |||||
| llm = LLMBuilder.to_llm( | |||||
| tenant_id=self._dataset.tenant_id, | |||||
| model_name='fake' | |||||
| ) | |||||
| service_context = ServiceContext.from_defaults( | |||||
| llm_predictor=LLMPredictor(llm=llm), | |||||
| embed_model=OpenAIEmbedding() | |||||
| ) | |||||
| dataset_keyword_table = self.get_keyword_table() | |||||
| if not dataset_keyword_table or not dataset_keyword_table.keyword_table_dict: | |||||
| return | |||||
| else: | |||||
| index_struct_dict = dataset_keyword_table.keyword_table_dict | |||||
| index_struct: KeywordTable = load_index_struct_from_dict(index_struct_dict) | |||||
| # create index | |||||
| index = GPTJIEBAKeywordTableIndex( | |||||
| index_struct=index_struct, | |||||
| docstore=EmptyDocumentStore(), | |||||
| service_context=service_context | |||||
| ) | |||||
| for node_id in node_ids: | |||||
| index.delete(node_id) | |||||
| index_struct_dict = index.index_struct.to_dict() | |||||
| if not dataset_keyword_table: | |||||
| dataset_keyword_table = DatasetKeywordTable( | |||||
| dataset_id=self._dataset.id, | |||||
| keyword_table=json.dumps(index_struct_dict) | |||||
| ) | |||||
| db.session.add(dataset_keyword_table) | |||||
| else: | |||||
| dataset_keyword_table.keyword_table = json.dumps(index_struct_dict) | |||||
| db.session.commit() | |||||
| @property | |||||
| def query_index(self) -> Optional[BaseGPTKeywordTableIndex]: | |||||
| docstore = DatesetDocumentStore( | |||||
| dataset=self._dataset, | |||||
| user_id=self._dataset.created_by, | |||||
| embedding_model_name="text-embedding-ada-002" | |||||
| ) | |||||
| service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id) | |||||
| dataset_keyword_table = self.get_keyword_table() | |||||
| if not dataset_keyword_table or not dataset_keyword_table.keyword_table_dict: | |||||
| return None | |||||
| index_struct: KeywordTable = load_index_struct_from_dict(dataset_keyword_table.keyword_table_dict) | |||||
| return GPTJIEBAKeywordTableIndex(index_struct=index_struct, docstore=docstore, service_context=service_context) | |||||
| def get_keyword_table(self): | |||||
| dataset_keyword_table = self._dataset.dataset_keyword_table | |||||
| if dataset_keyword_table: | |||||
| return dataset_keyword_table | |||||
| return None | |||||
| def update_segment_keywords(self, node_id: str, keywords: List[str]): | |||||
| document_segment = db.session.query(DocumentSegment).filter(DocumentSegment.index_node_id == node_id).first() | |||||
| if document_segment: | |||||
| document_segment.keywords = keywords | |||||
| db.session.commit() |
| import re | |||||
| from typing import Set | |||||
| import jieba | |||||
| from jieba.analyse import default_tfidf | |||||
| from core.index.keyword_table_index.stopwords import STOPWORDS | |||||
| class JiebaKeywordTableHandler: | |||||
| def __init__(self): | |||||
| default_tfidf.stop_words = STOPWORDS | |||||
| def extract_keywords(self, text: str, max_keywords_per_chunk: int = 10) -> Set[str]: | |||||
| """Extract keywords with JIEBA tfidf.""" | |||||
| keywords = jieba.analyse.extract_tags( | |||||
| sentence=text, | |||||
| topK=max_keywords_per_chunk, | |||||
| ) | |||||
| return set(self._expand_tokens_with_subtokens(keywords)) | |||||
| def _expand_tokens_with_subtokens(self, tokens: Set[str]) -> Set[str]: | |||||
| """Get subtokens from a list of tokens., filtering for stopwords.""" | |||||
| results = set() | |||||
| for token in tokens: | |||||
| results.add(token) | |||||
| sub_tokens = re.findall(r"\w+", token) | |||||
| if len(sub_tokens) > 1: | |||||
| results.update({w for w in sub_tokens if w not in STOPWORDS}) | |||||
| return results |
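# A minimal usage sketch of the handler above: it only patches jieba's TF-IDF stop
# words, so keyword extraction can be exercised on its own.
handler = JiebaKeywordTableHandler()
keywords = handler.extract_keywords("Dify datasets support keyword-table indexing", max_keywords_per_chunk=10)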
| import json | |||||
| from collections import defaultdict | |||||
| from typing import Any, List, Optional, Dict | |||||
| from langchain.schema import Document, BaseRetriever | |||||
| from pydantic import BaseModel, Field, Extra | |||||
| from core.index.base import BaseIndex | |||||
| from core.index.keyword_table_index.jieba_keyword_table_handler import JiebaKeywordTableHandler | |||||
| from extensions.ext_database import db | |||||
| from models.dataset import Dataset, DocumentSegment, DatasetKeywordTable | |||||
| class KeywordTableConfig(BaseModel): | |||||
| max_keywords_per_chunk: int = 10 | |||||
| class KeywordTableIndex(BaseIndex): | |||||
| def __init__(self, dataset: Dataset, config: KeywordTableConfig = KeywordTableConfig()): | |||||
| super().__init__(dataset) | |||||
| self._config = config | |||||
| def create(self, texts: list[Document], **kwargs) -> BaseIndex: | |||||
| keyword_table_handler = JiebaKeywordTableHandler() | |||||
| keyword_table = {} | |||||
| for text in texts: | |||||
| keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk) | |||||
| self._update_segment_keywords(text.metadata['doc_id'], list(keywords)) | |||||
| keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords)) | |||||
| dataset_keyword_table = DatasetKeywordTable( | |||||
| dataset_id=self.dataset.id, | |||||
| keyword_table=json.dumps({ | |||||
| '__type__': 'keyword_table', | |||||
| '__data__': { | |||||
| "index_id": self.dataset.id, | |||||
| "summary": None, | |||||
| "table": {} | |||||
| } | |||||
| }, cls=SetEncoder) | |||||
| ) | |||||
| db.session.add(dataset_keyword_table) | |||||
| db.session.commit() | |||||
| self._save_dataset_keyword_table(keyword_table) | |||||
| return self | |||||
| def add_texts(self, texts: list[Document], **kwargs): | |||||
| keyword_table_handler = JiebaKeywordTableHandler() | |||||
| keyword_table = self._get_dataset_keyword_table() | |||||
| for text in texts: | |||||
| keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk) | |||||
| self._update_segment_keywords(text.metadata['doc_id'], list(keywords)) | |||||
| keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords)) | |||||
| self._save_dataset_keyword_table(keyword_table) | |||||
| def text_exists(self, id: str) -> bool: | |||||
| keyword_table = self._get_dataset_keyword_table() | |||||
| # guard against an empty table: set.union() with no arguments raises a TypeError | |||||
| return id in set().union(*keyword_table.values()) | |||||
| def delete_by_ids(self, ids: list[str]) -> None: | |||||
| keyword_table = self._get_dataset_keyword_table() | |||||
| keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids) | |||||
| self._save_dataset_keyword_table(keyword_table) | |||||
| def delete_by_document_id(self, document_id: str): | |||||
| # get segment ids by document_id | |||||
| segments = db.session.query(DocumentSegment).filter( | |||||
| DocumentSegment.dataset_id == self.dataset.id, | |||||
| DocumentSegment.document_id == document_id | |||||
| ).all() | |||||
| ids = [segment.id for segment in segments] | |||||
| keyword_table = self._get_dataset_keyword_table() | |||||
| keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids) | |||||
| self._save_dataset_keyword_table(keyword_table) | |||||
| def get_retriever(self, **kwargs: Any) -> BaseRetriever: | |||||
| return KeywordTableRetriever(index=self, **kwargs) | |||||
| def search( | |||||
| self, query: str, | |||||
| **kwargs: Any | |||||
| ) -> List[Document]: | |||||
| keyword_table = self._get_dataset_keyword_table() | |||||
| search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {} | |||||
| k = search_kwargs.get('k') if search_kwargs.get('k') else 4 | |||||
| sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table, query, k) | |||||
| documents = [] | |||||
| for chunk_index in sorted_chunk_indices: | |||||
| segment = db.session.query(DocumentSegment).filter( | |||||
| DocumentSegment.dataset_id == self.dataset.id, | |||||
| DocumentSegment.index_node_id == chunk_index | |||||
| ).first() | |||||
| if segment: | |||||
| documents.append(Document( | |||||
| page_content=segment.content, | |||||
| metadata={ | |||||
| "doc_id": chunk_index, | |||||
| "document_id": segment.document_id, | |||||
| "dataset_id": segment.dataset_id, | |||||
| } | |||||
| )) | |||||
| return documents | |||||
| def delete(self) -> None: | |||||
| dataset_keyword_table = self.dataset.dataset_keyword_table | |||||
| if dataset_keyword_table: | |||||
| db.session.delete(dataset_keyword_table) | |||||
| db.session.commit() | |||||
| def _save_dataset_keyword_table(self, keyword_table): | |||||
| keyword_table_dict = { | |||||
| '__type__': 'keyword_table', | |||||
| '__data__': { | |||||
| "index_id": self.dataset.id, | |||||
| "summary": None, | |||||
| "table": keyword_table | |||||
| } | |||||
| } | |||||
| self.dataset.dataset_keyword_table.keyword_table = json.dumps(keyword_table_dict, cls=SetEncoder) | |||||
| db.session.commit() | |||||
| def _get_dataset_keyword_table(self) -> Optional[dict]: | |||||
| dataset_keyword_table = self.dataset.dataset_keyword_table | |||||
| if dataset_keyword_table: | |||||
| if dataset_keyword_table.keyword_table_dict: | |||||
| return dataset_keyword_table.keyword_table_dict['__data__']['table'] | |||||
| else: | |||||
| dataset_keyword_table = DatasetKeywordTable( | |||||
| dataset_id=self.dataset.id, | |||||
| keyword_table=json.dumps({ | |||||
| '__type__': 'keyword_table', | |||||
| '__data__': { | |||||
| "index_id": self.dataset.id, | |||||
| "summary": None, | |||||
| "table": {} | |||||
| } | |||||
| }, cls=SetEncoder) | |||||
| ) | |||||
| db.session.add(dataset_keyword_table) | |||||
| db.session.commit() | |||||
| return {} | |||||
| def _add_text_to_keyword_table(self, keyword_table: dict, id: str, keywords: list[str]) -> dict: | |||||
| for keyword in keywords: | |||||
| if keyword not in keyword_table: | |||||
| keyword_table[keyword] = set() | |||||
| keyword_table[keyword].add(id) | |||||
| return keyword_table | |||||
| def _delete_ids_from_keyword_table(self, keyword_table: dict, ids: list[str]) -> dict: | |||||
| # get set of ids that correspond to node | |||||
| node_idxs_to_delete = set(ids) | |||||
| # delete node_idxs from keyword to node idxs mapping | |||||
| keywords_to_delete = set() | |||||
| for keyword, node_idxs in keyword_table.items(): | |||||
| if node_idxs_to_delete.intersection(node_idxs): | |||||
| keyword_table[keyword] = node_idxs.difference( | |||||
| node_idxs_to_delete | |||||
| ) | |||||
| if not keyword_table[keyword]: | |||||
| keywords_to_delete.add(keyword) | |||||
| for keyword in keywords_to_delete: | |||||
| del keyword_table[keyword] | |||||
| return keyword_table | |||||
| def _retrieve_ids_by_query(self, keyword_table: dict, query: str, k: int = 4): | |||||
| keyword_table_handler = JiebaKeywordTableHandler() | |||||
| keywords = keyword_table_handler.extract_keywords(query) | |||||
| # go through text chunks in order of most matching keywords | |||||
| chunk_indices_count: Dict[str, int] = defaultdict(int) | |||||
| keywords = [keyword for keyword in keywords if keyword in set(keyword_table.keys())] | |||||
| for keyword in keywords: | |||||
| for node_id in keyword_table[keyword]: | |||||
| chunk_indices_count[node_id] += 1 | |||||
| sorted_chunk_indices = sorted( | |||||
| list(chunk_indices_count.keys()), | |||||
| key=lambda x: chunk_indices_count[x], | |||||
| reverse=True, | |||||
| ) | |||||
| return sorted_chunk_indices[: k] | |||||
| def _update_segment_keywords(self, node_id: str, keywords: List[str]): | |||||
| document_segment = db.session.query(DocumentSegment).filter(DocumentSegment.index_node_id == node_id).first() | |||||
| if document_segment: | |||||
| document_segment.keywords = keywords | |||||
| db.session.commit() | |||||
| class KeywordTableRetriever(BaseRetriever, BaseModel): | |||||
| index: KeywordTableIndex | |||||
| search_kwargs: dict = Field(default_factory=dict) | |||||
| class Config: | |||||
| """Configuration for this pydantic object.""" | |||||
| extra = Extra.forbid | |||||
| arbitrary_types_allowed = True | |||||
| def get_relevant_documents(self, query: str) -> List[Document]: | |||||
| """Get documents relevant for a query. | |||||
| Args: | |||||
| query: string to find relevant documents for | |||||
| Returns: | |||||
| List of relevant documents | |||||
| """ | |||||
| return self.index.search(query, **self.search_kwargs) | |||||
| async def aget_relevant_documents(self, query: str) -> List[Document]: | |||||
| raise NotImplementedError("KeywordTableRetriever does not support async") | |||||
| class SetEncoder(json.JSONEncoder): | |||||
| def default(self, obj): | |||||
| if isinstance(obj, set): | |||||
| return list(obj) | |||||
| return super().default(obj) |
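# A minimal usage sketch of the keyword-table index above, assuming `dataset` is an
# existing Dataset row, an active app context, and that each Document's metadata
# carries the segment's index_node_id as 'doc_id' (the id below is a placeholder).
from langchain.schema import Document

docs = [Document(page_content="Dify supports datasets.", metadata={'doc_id': 'node-1'})]
index = KeywordTableIndex(dataset=dataset, config=KeywordTableConfig(max_keywords_per_chunk=10))
index.create(docs)
results = index.search("datasets", search_kwargs={'k': 4})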
| from typing import ( | |||||
| Any, | |||||
| Dict, | |||||
| Optional, Sequence, | |||||
| ) | |||||
| from llama_index.indices.response.response_synthesis import ResponseSynthesizer | |||||
| from llama_index.indices.response.response_builder import ResponseMode, BaseResponseBuilder, get_response_builder | |||||
| from llama_index.indices.service_context import ServiceContext | |||||
| from llama_index.optimization.optimizer import BaseTokenUsageOptimizer | |||||
| from llama_index.prompts.prompts import ( | |||||
| QuestionAnswerPrompt, | |||||
| RefinePrompt, | |||||
| SimpleInputPrompt, | |||||
| ) | |||||
| from llama_index.types import RESPONSE_TEXT_TYPE | |||||
| class EnhanceResponseSynthesizer(ResponseSynthesizer): | |||||
| @classmethod | |||||
| def from_args( | |||||
| cls, | |||||
| service_context: ServiceContext, | |||||
| streaming: bool = False, | |||||
| use_async: bool = False, | |||||
| text_qa_template: Optional[QuestionAnswerPrompt] = None, | |||||
| refine_template: Optional[RefinePrompt] = None, | |||||
| simple_template: Optional[SimpleInputPrompt] = None, | |||||
| response_mode: ResponseMode = ResponseMode.DEFAULT, | |||||
| response_kwargs: Optional[Dict] = None, | |||||
| optimizer: Optional[BaseTokenUsageOptimizer] = None, | |||||
| ) -> "ResponseSynthesizer": | |||||
| response_builder: Optional[BaseResponseBuilder] = None | |||||
| if response_mode != ResponseMode.NO_TEXT: | |||||
| if response_mode == 'no_synthesizer': | |||||
| response_builder = NoSynthesizer( | |||||
| service_context=service_context, | |||||
| simple_template=simple_template, | |||||
| streaming=streaming, | |||||
| ) | |||||
| else: | |||||
| response_builder = get_response_builder( | |||||
| service_context, | |||||
| text_qa_template, | |||||
| refine_template, | |||||
| simple_template, | |||||
| response_mode, | |||||
| use_async=use_async, | |||||
| streaming=streaming, | |||||
| ) | |||||
| return cls(response_builder, response_mode, response_kwargs, optimizer) | |||||
| class NoSynthesizer(BaseResponseBuilder): | |||||
| def __init__( | |||||
| self, | |||||
| service_context: ServiceContext, | |||||
| simple_template: Optional[SimpleInputPrompt] = None, | |||||
| streaming: bool = False, | |||||
| ) -> None: | |||||
| super().__init__(service_context, streaming) | |||||
| async def aget_response( | |||||
| self, | |||||
| query_str: str, | |||||
| text_chunks: Sequence[str], | |||||
| prev_response: Optional[str] = None, | |||||
| **response_kwargs: Any, | |||||
| ) -> RESPONSE_TEXT_TYPE: | |||||
| return "\n".join(text_chunks) | |||||
| def get_response( | |||||
| self, | |||||
| query_str: str, | |||||
| text_chunks: Sequence[str], | |||||
| prev_response: Optional[str] = None, | |||||
| **response_kwargs: Any, | |||||
| ) -> RESPONSE_TEXT_TYPE: | |||||
| return "\n".join(text_chunks) |
| from pathlib import Path | |||||
| from typing import Dict | |||||
| from bs4 import BeautifulSoup | |||||
| from llama_index.readers.file.base_parser import BaseParser | |||||
| class HTMLParser(BaseParser): | |||||
| """HTML parser.""" | |||||
| def _init_parser(self) -> Dict: | |||||
| """Init parser.""" | |||||
| return {} | |||||
| def parse_file(self, file: Path, errors: str = "ignore") -> str: | |||||
| """Parse file.""" | |||||
| with open(file, "rb") as fp: | |||||
| soup = BeautifulSoup(fp, 'html.parser') | |||||
| text = soup.get_text() | |||||
| text = text.strip() if text else '' | |||||
| return text |
| """Markdown parser. | |||||
| Contains parser for md files. | |||||
| """ | |||||
| import re | |||||
| from pathlib import Path | |||||
| from typing import Any, Dict, List, Optional, Tuple, Union, cast | |||||
| from llama_index.readers.file.base_parser import BaseParser | |||||
| class MarkdownParser(BaseParser): | |||||
| """Markdown parser. | |||||
| Extract text from markdown files. | |||||
| Splits the text into (header, text) sections using markdown headers. | |||||
| """ | |||||
| def __init__( | |||||
| self, | |||||
| *args: Any, | |||||
| remove_hyperlinks: bool = True, | |||||
| remove_images: bool = True, | |||||
| **kwargs: Any, | |||||
| ) -> None: | |||||
| """Init params.""" | |||||
| super().__init__(*args, **kwargs) | |||||
| self._remove_hyperlinks = remove_hyperlinks | |||||
| self._remove_images = remove_images | |||||
| def markdown_to_tups(self, markdown_text: str) -> List[Tuple[Optional[str], str]]: | |||||
| """Convert a markdown file to a dictionary. | |||||
| The keys are the headers and the values are the text under each header. | |||||
| """ | |||||
| markdown_tups: List[Tuple[Optional[str], str]] = [] | |||||
| lines = markdown_text.split("\n") | |||||
| current_header = None | |||||
| current_text = "" | |||||
| for line in lines: | |||||
| header_match = re.match(r"^#+\s", line) | |||||
| if header_match: | |||||
| if current_header is not None: | |||||
| markdown_tups.append((current_header, current_text)) | |||||
| current_header = line | |||||
| current_text = "" | |||||
| else: | |||||
| current_text += line + "\n" | |||||
| markdown_tups.append((current_header, current_text)) | |||||
| if current_header is not None: | |||||
| # pass linting, assert keys are defined | |||||
| markdown_tups = [ | |||||
| (re.sub(r"#", "", cast(str, key)).strip(), re.sub(r"<.*?>", "", value)) | |||||
| for key, value in markdown_tups | |||||
| ] | |||||
| else: | |||||
| markdown_tups = [ | |||||
| (key, re.sub("\n", "", value)) for key, value in markdown_tups | |||||
| ] | |||||
| return markdown_tups | |||||
| def remove_images(self, content: str) -> str: | |||||
| """Get a dictionary of a markdown file from its path.""" | |||||
| pattern = r"!{1}\[\[(.*)\]\]" | |||||
| content = re.sub(pattern, "", content) | |||||
| return content | |||||
| def remove_hyperlinks(self, content: str) -> str: | |||||
| """Get a dictionary of a markdown file from its path.""" | |||||
| pattern = r"\[(.*?)\]\((.*?)\)" | |||||
| content = re.sub(pattern, r"\1", content) | |||||
| return content | |||||
| def _init_parser(self) -> Dict: | |||||
| """Initialize the parser with the config.""" | |||||
| return {} | |||||
| def parse_tups( | |||||
| self, filepath: Path, errors: str = "ignore" | |||||
| ) -> List[Tuple[Optional[str], str]]: | |||||
| """Parse file into tuples.""" | |||||
| with open(filepath, "r", encoding="utf-8") as f: | |||||
| content = f.read() | |||||
| if self._remove_hyperlinks: | |||||
| content = self.remove_hyperlinks(content) | |||||
| if self._remove_images: | |||||
| content = self.remove_images(content) | |||||
| markdown_tups = self.markdown_to_tups(content) | |||||
| return markdown_tups | |||||
| def parse_file( | |||||
| self, filepath: Path, errors: str = "ignore" | |||||
| ) -> Union[str, List[str]]: | |||||
| """Parse file into string.""" | |||||
| tups = self.parse_tups(filepath, errors=errors) | |||||
| results = [] | |||||
| # TODO: don't include headers right now | |||||
| for header, value in tups: | |||||
| if header is None: | |||||
| results.append(value) | |||||
| else: | |||||
| results.append(f"\n\n{header}\n{value}") | |||||
| return results |
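# A minimal usage sketch of the parser above; the path is a placeholder. Hyperlinks
# and image references are stripped before the text is split into per-header sections.
from pathlib import Path

parser = MarkdownParser(remove_hyperlinks=True, remove_images=True)
sections = parser.parse_file(Path('/tmp/readme.md'))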
| from pathlib import Path | |||||
| from typing import Dict | |||||
| from flask import current_app | |||||
| from llama_index.readers.file.base_parser import BaseParser | |||||
| from pypdf import PdfReader | |||||
| from extensions.ext_storage import storage | |||||
| from models.model import UploadFile | |||||
| class PDFParser(BaseParser): | |||||
| """PDF parser.""" | |||||
| def _init_parser(self) -> Dict: | |||||
| """Init parser.""" | |||||
| return {} | |||||
| def parse_file(self, file: Path, errors: str = "ignore") -> str: | |||||
| """Parse file.""" | |||||
| if not current_app.config.get('PDF_PREVIEW', True): | |||||
| return '' | |||||
| plaintext_file_key = '' | |||||
| plaintext_file_exists = False | |||||
| if self._parser_config and 'upload_file' in self._parser_config and self._parser_config['upload_file']: | |||||
| upload_file: UploadFile = self._parser_config['upload_file'] | |||||
| if upload_file.hash: | |||||
| plaintext_file_key = 'upload_files/' + upload_file.tenant_id + '/' + upload_file.hash + '.plaintext' | |||||
| try: | |||||
| text = storage.load(plaintext_file_key).decode('utf-8') | |||||
| plaintext_file_exists = True | |||||
| return text | |||||
| except FileNotFoundError: | |||||
| pass | |||||
| text_list = [] | |||||
| with open(file, "rb") as fp: | |||||
| # Create a PDF object | |||||
| pdf = PdfReader(fp) | |||||
| # Get the number of pages in the PDF document | |||||
| num_pages = len(pdf.pages) | |||||
| # Iterate over every page | |||||
| for page in range(num_pages): | |||||
| # Extract the text from the page | |||||
| page_text = pdf.pages[page].extract_text() | |||||
| text_list.append(page_text) | |||||
| text = "\n".join(text_list) | |||||
| # save plaintext file for caching | |||||
| if not plaintext_file_exists and plaintext_file_key: | |||||
| storage.save(plaintext_file_key, text.encode('utf-8')) | |||||
| return text |
| from pathlib import Path | |||||
| import json | |||||
| from typing import Dict | |||||
| from openpyxl import load_workbook | |||||
| from llama_index.readers.file.base_parser import BaseParser | |||||
| from flask import current_app | |||||
| class XLSXParser(BaseParser): | |||||
| """XLSX parser.""" | |||||
| def _init_parser(self) -> Dict: | |||||
| """Init parser""" | |||||
| return {} | |||||
| def parse_file(self, file: Path, errors: str = "ignore") -> str: | |||||
| data = [] | |||||
| keys = [] | |||||
| # load_workbook opens the workbook itself; no separate text-mode file handle is needed | |||||
| wb = load_workbook(filename=file, read_only=True) | |||||
| # loop over all sheets | |||||
| for sheet in wb: | |||||
| for row in sheet.iter_rows(values_only=True): | |||||
| if all(v is None for v in row): | |||||
| continue | |||||
| if keys == []: | |||||
| keys = list(map(str, row)) | |||||
| else: | |||||
| row_dict = dict(zip(keys, row)) | |||||
| row_dict = {k: v for k, v in row_dict.items() if v} | |||||
| data.append(json.dumps(row_dict, ensure_ascii=False)) | |||||
| return '\n\n'.join(data) |
| import json | |||||
| import logging | |||||
| from typing import List, Optional | |||||
| from llama_index.data_structs import Node | |||||
| from requests import ReadTimeout | |||||
| from sqlalchemy.exc import IntegrityError | |||||
| from tenacity import retry, stop_after_attempt, retry_if_exception_type | |||||
| from core.index.index_builder import IndexBuilder | |||||
| from core.vector_store.base import BaseGPTVectorStoreIndex | |||||
| from extensions.ext_vector_store import vector_store | |||||
| from extensions.ext_database import db | |||||
| from models.dataset import Dataset, Embedding | |||||
| class VectorIndex: | |||||
| def __init__(self, dataset: Dataset): | |||||
| self._dataset = dataset | |||||
| def add_nodes(self, nodes: List[Node], duplicate_check: bool = False): | |||||
| if not self._dataset.index_struct_dict: | |||||
| index_id = "Vector_index_" + self._dataset.id.replace("-", "_") | |||||
| self._dataset.index_struct = json.dumps(vector_store.to_index_struct(index_id)) | |||||
| db.session.commit() | |||||
| service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id) | |||||
| index = vector_store.get_index( | |||||
| service_context=service_context, | |||||
| index_struct=self._dataset.index_struct_dict | |||||
| ) | |||||
| if duplicate_check: | |||||
| nodes = self._filter_duplicate_nodes(index, nodes) | |||||
| embedding_queue_nodes = [] | |||||
| embedded_nodes = [] | |||||
| for node in nodes: | |||||
| node_hash = node.doc_hash | |||||
| # if node hash in cached embedding tables, use cached embedding | |||||
| embedding = db.session.query(Embedding).filter_by(hash=node_hash).first() | |||||
| if embedding: | |||||
| node.embedding = embedding.get_embedding() | |||||
| embedded_nodes.append(node) | |||||
| else: | |||||
| embedding_queue_nodes.append(node) | |||||
| if embedding_queue_nodes: | |||||
| embedding_results = index._get_node_embedding_results( | |||||
| embedding_queue_nodes, | |||||
| set(), | |||||
| ) | |||||
| # pre embed nodes for cached embedding | |||||
| for embedding_result in embedding_results: | |||||
| node = embedding_result.node | |||||
| node.embedding = embedding_result.embedding | |||||
| try: | |||||
| embedding = Embedding(hash=node.doc_hash) | |||||
| embedding.set_embedding(node.embedding) | |||||
| db.session.add(embedding) | |||||
| db.session.commit() | |||||
| except IntegrityError: | |||||
| db.session.rollback() | |||||
| continue | |||||
| except Exception: | |||||
| logging.exception('Failed to add embedding to db') | |||||
| continue | |||||
| embedded_nodes.append(node) | |||||
| self.index_insert_nodes(index, embedded_nodes) | |||||
| @retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3)) | |||||
| def index_insert_nodes(self, index: BaseGPTVectorStoreIndex, nodes: List[Node]): | |||||
| index.insert_nodes(nodes) | |||||
| def del_nodes(self, node_ids: List[str]): | |||||
| if not self._dataset.index_struct_dict: | |||||
| return | |||||
| service_context = IndexBuilder.get_fake_llm_service_context(tenant_id=self._dataset.tenant_id) | |||||
| index = vector_store.get_index( | |||||
| service_context=service_context, | |||||
| index_struct=self._dataset.index_struct_dict | |||||
| ) | |||||
| for node_id in node_ids: | |||||
| self.index_delete_node(index, node_id) | |||||
| @retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3)) | |||||
| def index_delete_node(self, index: BaseGPTVectorStoreIndex, node_id: str): | |||||
| index.delete_node(node_id) | |||||
| def del_doc(self, doc_id: str): | |||||
| if not self._dataset.index_struct_dict: | |||||
| return | |||||
| service_context = IndexBuilder.get_fake_llm_service_context(tenant_id=self._dataset.tenant_id) | |||||
| index = vector_store.get_index( | |||||
| service_context=service_context, | |||||
| index_struct=self._dataset.index_struct_dict | |||||
| ) | |||||
| self.index_delete_doc(index, doc_id) | |||||
| @retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3)) | |||||
| def index_delete_doc(self, index: BaseGPTVectorStoreIndex, doc_id: str): | |||||
| index.delete(doc_id) | |||||
| @property | |||||
| def query_index(self) -> Optional[BaseGPTVectorStoreIndex]: | |||||
| if not self._dataset.index_struct_dict: | |||||
| return None | |||||
| service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id) | |||||
| return vector_store.get_index( | |||||
| service_context=service_context, | |||||
| index_struct=self._dataset.index_struct_dict | |||||
| ) | |||||
| def _filter_duplicate_nodes(self, index: BaseGPTVectorStoreIndex, nodes: List[Node]) -> List[Node]: | |||||
| # iterate over a copy: removing items from the list being iterated would skip elements | |||||
| for node in nodes[:]: | |||||
| node_id = node.doc_id | |||||
| if index.exists_by_node_id(node_id): | |||||
| nodes.remove(node) | |||||
| return nodes | |||||
| import json | |||||
| import logging | |||||
| from abc import abstractmethod | |||||
| from typing import List, Any, cast | |||||
| from langchain.embeddings.base import Embeddings | |||||
| from langchain.schema import Document, BaseRetriever | |||||
| from langchain.vectorstores import VectorStore | |||||
| from weaviate import UnexpectedStatusCodeException | |||||
| from core.index.base import BaseIndex | |||||
| from extensions.ext_database import db | |||||
| from models.dataset import Dataset, DocumentSegment | |||||
| from models.dataset import Document as DatasetDocument | |||||
| class BaseVectorIndex(BaseIndex): | |||||
| def __init__(self, dataset: Dataset, embeddings: Embeddings): | |||||
| super().__init__(dataset) | |||||
| self._embeddings = embeddings | |||||
| self._vector_store = None | |||||
| def get_type(self) -> str: | |||||
| raise NotImplementedError | |||||
| @abstractmethod | |||||
| def get_index_name(self, dataset: Dataset) -> str: | |||||
| raise NotImplementedError | |||||
| @abstractmethod | |||||
| def to_index_struct(self) -> dict: | |||||
| raise NotImplementedError | |||||
| @abstractmethod | |||||
| def _get_vector_store(self) -> VectorStore: | |||||
| raise NotImplementedError | |||||
| @abstractmethod | |||||
| def _get_vector_store_class(self) -> type: | |||||
| raise NotImplementedError | |||||
| def search( | |||||
| self, query: str, | |||||
| **kwargs: Any | |||||
| ) -> List[Document]: | |||||
| vector_store = self._get_vector_store() | |||||
| vector_store = cast(self._get_vector_store_class(), vector_store) | |||||
| search_type = kwargs.get('search_type') if kwargs.get('search_type') else 'similarity' | |||||
| search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {} | |||||
| if search_type == 'similarity_score_threshold': | |||||
| score_threshold = search_kwargs.get("score_threshold") | |||||
| if not isinstance(score_threshold, float): | |||||
| search_kwargs['score_threshold'] = 0.0 | |||||
| docs_with_similarity = vector_store.similarity_search_with_relevance_scores( | |||||
| query, **search_kwargs | |||||
| ) | |||||
| docs = [] | |||||
| for doc, similarity in docs_with_similarity: | |||||
| doc.metadata['score'] = similarity | |||||
| docs.append(doc) | |||||
| return docs | |||||
| # similarity k | |||||
| # mmr k, fetch_k, lambda_mult | |||||
| # similarity_score_threshold k | |||||
| return vector_store.as_retriever( | |||||
| search_type=search_type, | |||||
| search_kwargs=search_kwargs | |||||
| ).get_relevant_documents(query) | |||||
| def get_retriever(self, **kwargs: Any) -> BaseRetriever: | |||||
| vector_store = self._get_vector_store() | |||||
| vector_store = cast(self._get_vector_store_class(), vector_store) | |||||
| return vector_store.as_retriever(**kwargs) | |||||
| def add_texts(self, texts: list[Document], **kwargs): | |||||
| if self._is_origin(): | |||||
| self.recreate_dataset(self.dataset) | |||||
| vector_store = self._get_vector_store() | |||||
| vector_store = cast(self._get_vector_store_class(), vector_store) | |||||
| if kwargs.get('duplicate_check', False): | |||||
| texts = self._filter_duplicate_texts(texts) | |||||
| uuids = self._get_uuids(texts) | |||||
| vector_store.add_documents(texts, uuids=uuids) | |||||
| def text_exists(self, id: str) -> bool: | |||||
| vector_store = self._get_vector_store() | |||||
| vector_store = cast(self._get_vector_store_class(), vector_store) | |||||
| return vector_store.text_exists(id) | |||||
| def delete_by_ids(self, ids: list[str]) -> None: | |||||
| if self._is_origin(): | |||||
| self.recreate_dataset(self.dataset) | |||||
| return | |||||
| vector_store = self._get_vector_store() | |||||
| vector_store = cast(self._get_vector_store_class(), vector_store) | |||||
| for node_id in ids: | |||||
| vector_store.del_text(node_id) | |||||
| def delete(self) -> None: | |||||
| vector_store = self._get_vector_store() | |||||
| vector_store = cast(self._get_vector_store_class(), vector_store) | |||||
| vector_store.delete() | |||||
| def _is_origin(self): | |||||
| return False | |||||
| def recreate_dataset(self, dataset: Dataset): | |||||
| logging.info(f"Recreating dataset {dataset.id}") | |||||
| try: | |||||
| self.delete() | |||||
| except UnexpectedStatusCodeException as e: | |||||
| if e.status_code != 400: | |||||
| # 400 means index not exists | |||||
| raise e | |||||
| dataset_documents = db.session.query(DatasetDocument).filter( | |||||
| DatasetDocument.dataset_id == dataset.id, | |||||
| DatasetDocument.indexing_status == 'completed', | |||||
| DatasetDocument.enabled == True, | |||||
| DatasetDocument.archived == False, | |||||
| ).all() | |||||
| documents = [] | |||||
| for dataset_document in dataset_documents: | |||||
| segments = db.session.query(DocumentSegment).filter( | |||||
| DocumentSegment.document_id == dataset_document.id, | |||||
| DocumentSegment.status == 'completed', | |||||
| DocumentSegment.enabled == True | |||||
| ).all() | |||||
| for segment in segments: | |||||
| document = Document( | |||||
| page_content=segment.content, | |||||
| metadata={ | |||||
| "doc_id": segment.index_node_id, | |||||
| "doc_hash": segment.index_node_hash, | |||||
| "document_id": segment.document_id, | |||||
| "dataset_id": segment.dataset_id, | |||||
| } | |||||
| ) | |||||
| documents.append(document) | |||||
| origin_index_struct = self.dataset.index_struct | |||||
| self.dataset.index_struct = None | |||||
| if documents: | |||||
| try: | |||||
| self.create(documents) | |||||
| except Exception as e: | |||||
| self.dataset.index_struct = origin_index_struct | |||||
| raise e | |||||
| dataset.index_struct = json.dumps(self.to_index_struct()) | |||||
| db.session.commit() | |||||
| self.dataset = dataset | |||||
| logging.info(f"Dataset {dataset.id} recreate successfully.") |
| import os | |||||
| from typing import Optional, Any, List, cast | |||||
| import qdrant_client | |||||
| from langchain.embeddings.base import Embeddings | |||||
| from langchain.schema import Document, BaseRetriever | |||||
| from langchain.vectorstores import VectorStore | |||||
| from pydantic import BaseModel | |||||
| from core.index.base import BaseIndex | |||||
| from core.index.vector_index.base import BaseVectorIndex | |||||
| from core.vector_store.qdrant_vector_store import QdrantVectorStore | |||||
| from models.dataset import Dataset | |||||
| class QdrantConfig(BaseModel): | |||||
| endpoint: str | |||||
| api_key: Optional[str] | |||||
| root_path: Optional[str] | |||||
| def to_qdrant_params(self): | |||||
| if self.endpoint and self.endpoint.startswith('path:'): | |||||
| path = self.endpoint.replace('path:', '') | |||||
| if not os.path.isabs(path): | |||||
| path = os.path.join(self.root_path, path) | |||||
| return { | |||||
| 'path': path | |||||
| } | |||||
| else: | |||||
| return { | |||||
| 'url': self.endpoint, | |||||
| 'api_key': self.api_key, | |||||
| } | |||||
| class QdrantVectorIndex(BaseVectorIndex): | |||||
| def __init__(self, dataset: Dataset, config: QdrantConfig, embeddings: Embeddings): | |||||
| super().__init__(dataset, embeddings) | |||||
| self._client_config = config | |||||
| def get_type(self) -> str: | |||||
| return 'qdrant' | |||||
| def get_index_name(self, dataset: Dataset) -> str: | |||||
| if self.dataset.index_struct_dict: | |||||
| return self.dataset.index_struct_dict['vector_store']['collection_name'] | |||||
| dataset_id = dataset.id | |||||
| return "Index_" + dataset_id.replace("-", "_") | |||||
| def to_index_struct(self) -> dict: | |||||
| return { | |||||
| "type": self.get_type(), | |||||
| "vector_store": {"collection_name": self.get_index_name(self.dataset)} | |||||
| } | |||||
| def create(self, texts: list[Document], **kwargs) -> BaseIndex: | |||||
| uuids = self._get_uuids(texts) | |||||
| self._vector_store = QdrantVectorStore.from_documents( | |||||
| texts, | |||||
| self._embeddings, | |||||
| collection_name=self.get_index_name(self.dataset), | |||||
| ids=uuids, | |||||
| content_payload_key='text', | |||||
| **self._client_config.to_qdrant_params() | |||||
| ) | |||||
| return self | |||||
| def _get_vector_store(self) -> VectorStore: | |||||
| """Only for created index.""" | |||||
| if self._vector_store: | |||||
| return self._vector_store | |||||
| client = qdrant_client.QdrantClient( | |||||
| **self._client_config.to_qdrant_params() | |||||
| ) | |||||
| return QdrantVectorStore( | |||||
| client=client, | |||||
| collection_name=self.get_index_name(self.dataset), | |||||
| embeddings=self._embeddings, | |||||
| content_payload_key='text' | |||||
| ) | |||||
| def _get_vector_store_class(self) -> type: | |||||
| return QdrantVectorStore | |||||
| def delete_by_document_id(self, document_id: str): | |||||
| if self._is_origin(): | |||||
| self.recreate_dataset(self.dataset) | |||||
| return | |||||
| vector_store = self._get_vector_store() | |||||
| vector_store = cast(self._get_vector_store_class(), vector_store) | |||||
| from qdrant_client.http import models | |||||
| vector_store.del_texts(models.Filter( | |||||
| must=[ | |||||
| models.FieldCondition( | |||||
| key="metadata.document_id", | |||||
| match=models.MatchValue(value=document_id), | |||||
| ), | |||||
| ], | |||||
| )) | |||||
| def _is_origin(self): | |||||
| if self.dataset.index_struct_dict: | |||||
| class_prefix: str = self.dataset.index_struct_dict['vector_store']['collection_name'] | |||||
| if class_prefix.startswith('Vector_'): | |||||
| # original class_prefix | |||||
| return True | |||||
| return False |
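For orientation (the id below is made up): new datasets get a collection named from their id, while datasets that still reference a legacy 'Vector_'-prefixed collection are detected by _is_origin() and rebuilt wholesale rather than edited in place.

    # Hypothetical dataset id and the Qdrant collection name derived from it
    dataset_id = 'b7a4c2de-0f1a-4c3b-9a66-1234567890ab'
    collection_name = 'Index_' + dataset_id.replace('-', '_')
    # -> 'Index_b7a4c2de_0f1a_4c3b_9a66_1234567890ab'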
| import json | |||||
| from flask import current_app | |||||
| from langchain.embeddings.base import Embeddings | |||||
| from core.index.vector_index.base import BaseVectorIndex | |||||
| from extensions.ext_database import db | |||||
| from models.dataset import Dataset, Document | |||||
| class VectorIndex: | |||||
| def __init__(self, dataset: Dataset, config: dict, embeddings: Embeddings): | |||||
| self._dataset = dataset | |||||
| self._embeddings = embeddings | |||||
| self._vector_index = self._init_vector_index(dataset, config, embeddings) | |||||
| def _init_vector_index(self, dataset: Dataset, config: dict, embeddings: Embeddings) -> BaseVectorIndex: | |||||
| vector_type = config.get('VECTOR_STORE') | |||||
| if self._dataset.index_struct_dict: | |||||
| vector_type = self._dataset.index_struct_dict['type'] | |||||
| if not vector_type: | |||||
| raise ValueError(f"Vector store must be specified.") | |||||
| if vector_type == "weaviate": | |||||
| from core.index.vector_index.weaviate_vector_index import WeaviateVectorIndex, WeaviateConfig | |||||
| return WeaviateVectorIndex( | |||||
| dataset=dataset, | |||||
| config=WeaviateConfig( | |||||
| endpoint=config.get('WEAVIATE_ENDPOINT'), | |||||
| api_key=config.get('WEAVIATE_API_KEY'), | |||||
| batch_size=int(config.get('WEAVIATE_BATCH_SIZE')) | |||||
| ), | |||||
| embeddings=embeddings | |||||
| ) | |||||
| elif vector_type == "qdrant": | |||||
| from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig | |||||
| return QdrantVectorIndex( | |||||
| dataset=dataset, | |||||
| config=QdrantConfig( | |||||
| endpoint=config.get('QDRANT_URL'), | |||||
| api_key=config.get('QDRANT_API_KEY'), | |||||
| root_path=current_app.root_path | |||||
| ), | |||||
| embeddings=embeddings | |||||
| ) | |||||
| else: | |||||
| raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.") | |||||
| def add_texts(self, texts: list[Document], **kwargs): | |||||
| if not self._dataset.index_struct_dict: | |||||
| self._vector_index.create(texts, **kwargs) | |||||
| self._dataset.index_struct = json.dumps(self._vector_index.to_index_struct()) | |||||
| db.session.commit() | |||||
| return | |||||
| self._vector_index.add_texts(texts, **kwargs) | |||||
| def __getattr__(self, name): | |||||
| if self._vector_index is not None: | |||||
| method = getattr(self._vector_index, name) | |||||
| if callable(method): | |||||
| return method | |||||
| raise AttributeError(f"'VectorIndex' object has no attribute '{name}'") | |||||
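A minimal usage sketch of the factory above (configuration values are hypothetical, and dataset/embeddings are assumed to exist): the backend comes from the VECTOR_STORE setting for new datasets, or from the type stored in index_struct for existing ones. Because QdrantConfig reads current_app.root_path, this runs inside a Flask application context.

    # Hedged sketch: building a VectorIndex from app config and indexing some documents
    config = {
        'VECTOR_STORE': 'qdrant',                         # or 'weaviate'
        'QDRANT_URL': 'https://qdrant.example.com:6333',  # hypothetical
        'QDRANT_API_KEY': 'qd-xxxx',                      # hypothetical
    }
    index = VectorIndex(dataset=dataset, config=config, embeddings=embeddings)
    index.add_texts(documents)   # creates the collection on first use, appends afterwards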
| from typing import Optional, cast | |||||
| import weaviate | |||||
| from langchain.embeddings.base import Embeddings | |||||
| from langchain.schema import Document, BaseRetriever | |||||
| from langchain.vectorstores import VectorStore | |||||
| from pydantic import BaseModel, root_validator | |||||
| from core.index.base import BaseIndex | |||||
| from core.index.vector_index.base import BaseVectorIndex | |||||
| from core.vector_store.weaviate_vector_store import WeaviateVectorStore | |||||
| from models.dataset import Dataset | |||||
| class WeaviateConfig(BaseModel): | |||||
| endpoint: str | |||||
| api_key: Optional[str] | |||||
| batch_size: int = 100 | |||||
| @root_validator() | |||||
| def validate_config(cls, values: dict) -> dict: | |||||
| if not values['endpoint']: | |||||
| raise ValueError("config WEAVIATE_ENDPOINT is required") | |||||
| return values | |||||
| class WeaviateVectorIndex(BaseVectorIndex): | |||||
| def __init__(self, dataset: Dataset, config: WeaviateConfig, embeddings: Embeddings): | |||||
| super().__init__(dataset, embeddings) | |||||
| self._client = self._init_client(config) | |||||
| def _init_client(self, config: WeaviateConfig) -> weaviate.Client: | |||||
| auth_config = weaviate.auth.AuthApiKey(api_key=config.api_key) | |||||
| weaviate.connect.connection.has_grpc = False | |||||
| client = weaviate.Client( | |||||
| url=config.endpoint, | |||||
| auth_client_secret=auth_config, | |||||
| timeout_config=(5, 60), | |||||
| startup_period=None | |||||
| ) | |||||
| client.batch.configure( | |||||
| # `batch_size` takes an `int` value to enable auto-batching | |||||
| # (`None` is used for manual batching) | |||||
| batch_size=config.batch_size, | |||||
| # dynamically update the `batch_size` based on import speed | |||||
| dynamic=True, | |||||
| # `timeout_retries` takes an `int` value to retry on time outs | |||||
| timeout_retries=3, | |||||
| ) | |||||
| return client | |||||
| def get_type(self) -> str: | |||||
| return 'weaviate' | |||||
| def get_index_name(self, dataset: Dataset) -> str: | |||||
| if self.dataset.index_struct_dict: | |||||
| class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix'] | |||||
| if not class_prefix.endswith('_Node'): | |||||
| # original class_prefix | |||||
| class_prefix += '_Node' | |||||
| return class_prefix | |||||
| dataset_id = dataset.id | |||||
| return "Vector_index_" + dataset_id.replace("-", "_") + '_Node' | |||||
| def to_index_struct(self) -> dict: | |||||
| return { | |||||
| "type": self.get_type(), | |||||
| "vector_store": {"class_prefix": self.get_index_name(self.dataset)} | |||||
| } | |||||
| def create(self, texts: list[Document], **kwargs) -> BaseIndex: | |||||
| uuids = self._get_uuids(texts) | |||||
| self._vector_store = WeaviateVectorStore.from_documents( | |||||
| texts, | |||||
| self._embeddings, | |||||
| client=self._client, | |||||
| index_name=self.get_index_name(self.dataset), | |||||
| uuids=uuids, | |||||
| by_text=False | |||||
| ) | |||||
| return self | |||||
| def _get_vector_store(self) -> VectorStore: | |||||
| """Only for created index.""" | |||||
| if self._vector_store: | |||||
| return self._vector_store | |||||
| attributes = ['doc_id', 'dataset_id', 'document_id'] | |||||
| if self._is_origin(): | |||||
| attributes = ['doc_id'] | |||||
| return WeaviateVectorStore( | |||||
| client=self._client, | |||||
| index_name=self.get_index_name(self.dataset), | |||||
| text_key='text', | |||||
| embedding=self._embeddings, | |||||
| attributes=attributes, | |||||
| by_text=False | |||||
| ) | |||||
| def _get_vector_store_class(self) -> type: | |||||
| return WeaviateVectorStore | |||||
| def delete_by_document_id(self, document_id: str): | |||||
| if self._is_origin(): | |||||
| self.recreate_dataset(self.dataset) | |||||
| return | |||||
| vector_store = self._get_vector_store() | |||||
| vector_store = cast(self._get_vector_store_class(), vector_store) | |||||
| vector_store.del_texts({ | |||||
| "operator": "Equal", | |||||
| "path": ["document_id"], | |||||
| "valueText": document_id | |||||
| }) | |||||
| def _is_origin(self): | |||||
| if self.dataset.index_struct_dict: | |||||
| class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix'] | |||||
| if not class_prefix.endswith('_Node'): | |||||
| # original class_prefix | |||||
| return True | |||||
| return False |
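As a small illustration (values are hypothetical), WeaviateConfig rejects a missing endpoint up front through its root_validator, so misconfiguration surfaces at construction time rather than on the first write:

    # Hypothetical Weaviate configuration; omitting the endpoint would fail validation
    config = WeaviateConfig(
        endpoint='https://weaviate.example.com',
        api_key='wv-xxxx',
        batch_size=100,
    )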
| import datetime | import datetime | ||||
| import json | import json | ||||
| import logging | |||||
| import re | import re | ||||
| import tempfile | |||||
| import time | import time | ||||
| from pathlib import Path | |||||
| from typing import Optional, List | |||||
| import uuid | |||||
| from typing import Optional, List, cast | |||||
| from flask import current_app | |||||
| from flask_login import current_user | from flask_login import current_user | ||||
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |||||
| from langchain.embeddings import OpenAIEmbeddings | |||||
| from langchain.schema import Document | |||||
| from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter | |||||
| from llama_index import SimpleDirectoryReader | |||||
| from llama_index.data_structs import Node | |||||
| from llama_index.data_structs.node_v2 import DocumentRelationship | |||||
| from llama_index.node_parser import SimpleNodeParser, NodeParser | |||||
| from llama_index.readers.file.base import DEFAULT_FILE_EXTRACTOR | |||||
| from llama_index.readers.file.markdown_parser import MarkdownParser | |||||
| from core.data_source.notion import NotionPageReader | |||||
| from core.index.readers.xlsx_parser import XLSXParser | |||||
| from core.data_loader.file_extractor import FileExtractor | |||||
| from core.data_loader.loader.notion import NotionLoader | |||||
| from core.docstore.dataset_docstore import DatesetDocumentStore | from core.docstore.dataset_docstore import DatesetDocumentStore | ||||
| from core.index.keyword_table_index import KeywordTableIndex | |||||
| from core.index.readers.html_parser import HTMLParser | |||||
| from core.index.readers.markdown_parser import MarkdownParser | |||||
| from core.index.readers.pdf_parser import PDFParser | |||||
| from core.index.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter | |||||
| from core.index.vector_index import VectorIndex | |||||
| from core.embedding.cached_embedding import CacheEmbedding | |||||
| from core.index.index import IndexBuilder | |||||
| from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig | |||||
| from core.index.vector_index.vector_index import VectorIndex | |||||
| from core.llm.error import ProviderTokenNotInitError | |||||
| from core.llm.llm_builder import LLMBuilder | |||||
| from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter | |||||
| from core.llm.token_calculator import TokenCalculator | from core.llm.token_calculator import TokenCalculator | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from extensions.ext_redis import redis_client | from extensions.ext_redis import redis_client | ||||
| from extensions.ext_storage import storage | from extensions.ext_storage import storage | ||||
| from models.dataset import Document, Dataset, DocumentSegment, DatasetProcessRule | |||||
| from libs import helper | |||||
| from models.dataset import Document as DatasetDocument | |||||
| from models.dataset import Dataset, DocumentSegment, DatasetProcessRule | |||||
| from models.model import UploadFile | from models.model import UploadFile | ||||
| from models.source import DataSourceBinding | from models.source import DataSourceBinding | ||||
| self.storage = storage | self.storage = storage | ||||
| self.embedding_model_name = embedding_model_name | self.embedding_model_name = embedding_model_name | ||||
| def run(self, documents: List[Document]): | |||||
| def run(self, dataset_documents: List[DatasetDocument]): | |||||
| """Run the indexing process.""" | """Run the indexing process.""" | ||||
| for document in documents: | |||||
| for dataset_document in dataset_documents: | |||||
| try: | |||||
| # get dataset | |||||
| dataset = Dataset.query.filter_by( | |||||
| id=dataset_document.dataset_id | |||||
| ).first() | |||||
| if not dataset: | |||||
| raise ValueError("no dataset found") | |||||
| # load file | |||||
| text_docs = self._load_data(dataset_document) | |||||
| # get the process rule | |||||
| processing_rule = db.session.query(DatasetProcessRule). \ | |||||
| filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \ | |||||
| first() | |||||
| # get splitter | |||||
| splitter = self._get_splitter(processing_rule) | |||||
| # split to documents | |||||
| documents = self._step_split( | |||||
| text_docs=text_docs, | |||||
| splitter=splitter, | |||||
| dataset=dataset, | |||||
| dataset_document=dataset_document, | |||||
| processing_rule=processing_rule | |||||
| ) | |||||
| # build index | |||||
| self._build_index( | |||||
| dataset=dataset, | |||||
| dataset_document=dataset_document, | |||||
| documents=documents | |||||
| ) | |||||
| except DocumentIsPausedException: | |||||
| raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id)) | |||||
| except ProviderTokenNotInitError as e: | |||||
| dataset_document.indexing_status = 'error' | |||||
| dataset_document.error = str(e.description) | |||||
| dataset_document.stopped_at = datetime.datetime.utcnow() | |||||
| db.session.commit() | |||||
| except Exception as e: | |||||
| logging.exception("consume document failed") | |||||
| dataset_document.indexing_status = 'error' | |||||
| dataset_document.error = str(e) | |||||
| dataset_document.stopped_at = datetime.datetime.utcnow() | |||||
| db.session.commit() | |||||
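Each document in run() is handled independently: a paused document aborts with DocumentIsPausedException, a missing provider credential records the provider's error description, and any other failure is logged and stored on the document. A condensed sketch of that contract, where process() and mark_error() are hypothetical stand-ins for the inline steps above:

    # Hedged sketch of the per-document error handling in run()
    try:
        process(dataset_document)                          # load -> split -> build index
    except DocumentIsPausedException:
        raise                                              # propagate; indexing stops here
    except ProviderTokenNotInitError as e:
        mark_error(dataset_document, str(e.description))   # credential problem
    except Exception as e:
        logging.exception("consume document failed")
        mark_error(dataset_document, str(e))               # anything else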
| def run_in_splitting_status(self, dataset_document: DatasetDocument): | |||||
| """Run the indexing process when the index_status is splitting.""" | |||||
| try: | |||||
| # get dataset | # get dataset | ||||
| dataset = Dataset.query.filter_by( | dataset = Dataset.query.filter_by( | ||||
| id=document.dataset_id | |||||
| id=dataset_document.dataset_id | |||||
| ).first() | ).first() | ||||
| if not dataset: | if not dataset: | ||||
| raise ValueError("no dataset found") | raise ValueError("no dataset found") | ||||
| # get existing document segments and delete them | |||||
| document_segments = DocumentSegment.query.filter_by( | |||||
| dataset_id=dataset.id, | |||||
| document_id=dataset_document.id | |||||
| ).all() | |||||
| for document_segment in document_segments: | |||||
| db.session.delete(document_segment) | |||||
| db.session.commit() | |||||
| # load file | # load file | ||||
| text_docs = self._load_data(document) | |||||
| text_docs = self._load_data(dataset_document) | |||||
| # get the process rule | # get the process rule | ||||
| processing_rule = db.session.query(DatasetProcessRule). \ | processing_rule = db.session.query(DatasetProcessRule). \ | ||||
| filter(DatasetProcessRule.id == document.dataset_process_rule_id). \ | |||||
| filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \ | |||||
| first() | first() | ||||
| # get node parser for splitting | |||||
| node_parser = self._get_node_parser(processing_rule) | |||||
| # get splitter | |||||
| splitter = self._get_splitter(processing_rule) | |||||
| # split to nodes | |||||
| nodes = self._step_split( | |||||
| # split to documents | |||||
| documents = self._step_split( | |||||
| text_docs=text_docs, | text_docs=text_docs, | ||||
| node_parser=node_parser, | |||||
| splitter=splitter, | |||||
| dataset=dataset, | dataset=dataset, | ||||
| document=document, | |||||
| dataset_document=dataset_document, | |||||
| processing_rule=processing_rule | processing_rule=processing_rule | ||||
| ) | ) | ||||
| # build index | # build index | ||||
| self._build_index( | self._build_index( | ||||
| dataset=dataset, | dataset=dataset, | ||||
| document=document, | |||||
| nodes=nodes | |||||
| dataset_document=dataset_document, | |||||
| documents=documents | |||||
| ) | ) | ||||
| except DocumentIsPausedException: | |||||
| raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id)) | |||||
| except ProviderTokenNotInitError as e: | |||||
| dataset_document.indexing_status = 'error' | |||||
| dataset_document.error = str(e.description) | |||||
| dataset_document.stopped_at = datetime.datetime.utcnow() | |||||
| db.session.commit() | |||||
| except Exception as e: | |||||
| logging.exception("consume document failed") | |||||
| dataset_document.indexing_status = 'error' | |||||
| dataset_document.error = str(e) | |||||
| dataset_document.stopped_at = datetime.datetime.utcnow() | |||||
| db.session.commit() | |||||
| def run_in_splitting_status(self, document: Document): | |||||
| """Run the indexing process when the index_status is splitting.""" | |||||
| # get dataset | |||||
| dataset = Dataset.query.filter_by( | |||||
| id=document.dataset_id | |||||
| ).first() | |||||
| if not dataset: | |||||
| raise ValueError("no dataset found") | |||||
| # get exist document_segment list and delete | |||||
| document_segments = DocumentSegment.query.filter_by( | |||||
| dataset_id=dataset.id, | |||||
| document_id=document.id | |||||
| ).all() | |||||
| db.session.delete(document_segments) | |||||
| db.session.commit() | |||||
| # load file | |||||
| text_docs = self._load_data(document) | |||||
| # get the process rule | |||||
| processing_rule = db.session.query(DatasetProcessRule). \ | |||||
| filter(DatasetProcessRule.id == document.dataset_process_rule_id). \ | |||||
| first() | |||||
| # get node parser for splitting | |||||
| node_parser = self._get_node_parser(processing_rule) | |||||
| def run_in_indexing_status(self, dataset_document: DatasetDocument): | |||||
| """Run the indexing process when the index_status is indexing.""" | |||||
| try: | |||||
| # get dataset | |||||
| dataset = Dataset.query.filter_by( | |||||
| id=dataset_document.dataset_id | |||||
| ).first() | |||||
| # split to nodes | |||||
| nodes = self._step_split( | |||||
| text_docs=text_docs, | |||||
| node_parser=node_parser, | |||||
| dataset=dataset, | |||||
| document=document, | |||||
| processing_rule=processing_rule | |||||
| ) | |||||
| if not dataset: | |||||
| raise ValueError("no dataset found") | |||||
| # build index | |||||
| self._build_index( | |||||
| dataset=dataset, | |||||
| document=document, | |||||
| nodes=nodes | |||||
| ) | |||||
| # get existing document segments; only those not yet completed are re-indexed | |||||
| document_segments = DocumentSegment.query.filter_by( | |||||
| dataset_id=dataset.id, | |||||
| document_id=dataset_document.id | |||||
| ).all() | |||||
| documents = [] | |||||
| if document_segments: | |||||
| for document_segment in document_segments: | |||||
| # transform segment to node | |||||
| if document_segment.status != "completed": | |||||
| document = Document( | |||||
| page_content=document_segment.content, | |||||
| metadata={ | |||||
| "doc_id": document_segment.index_node_id, | |||||
| "doc_hash": document_segment.index_node_hash, | |||||
| "document_id": document_segment.document_id, | |||||
| "dataset_id": document_segment.dataset_id, | |||||
| } | |||||
| ) | |||||
| documents.append(document) | |||||
| def run_in_indexing_status(self, document: Document): | |||||
| """Run the indexing process when the index_status is indexing.""" | |||||
| # get dataset | |||||
| dataset = Dataset.query.filter_by( | |||||
| id=document.dataset_id | |||||
| ).first() | |||||
| if not dataset: | |||||
| raise ValueError("no dataset found") | |||||
| # get exist document_segment list and delete | |||||
| document_segments = DocumentSegment.query.filter_by( | |||||
| dataset_id=dataset.id, | |||||
| document_id=document.id | |||||
| ).all() | |||||
| nodes = [] | |||||
| if document_segments: | |||||
| for document_segment in document_segments: | |||||
| # transform segment to node | |||||
| if document_segment.status != "completed": | |||||
| relationships = { | |||||
| DocumentRelationship.SOURCE: document_segment.document_id, | |||||
| } | |||||
| previous_segment = document_segment.previous_segment | |||||
| if previous_segment: | |||||
| relationships[DocumentRelationship.PREVIOUS] = previous_segment.index_node_id | |||||
| next_segment = document_segment.next_segment | |||||
| if next_segment: | |||||
| relationships[DocumentRelationship.NEXT] = next_segment.index_node_id | |||||
| node = Node( | |||||
| doc_id=document_segment.index_node_id, | |||||
| doc_hash=document_segment.index_node_hash, | |||||
| text=document_segment.content, | |||||
| extra_info=None, | |||||
| node_info=None, | |||||
| relationships=relationships | |||||
| ) | |||||
| nodes.append(node) | |||||
| # build index | |||||
| self._build_index( | |||||
| dataset=dataset, | |||||
| document=document, | |||||
| nodes=nodes | |||||
| ) | |||||
| # build index | |||||
| self._build_index( | |||||
| dataset=dataset, | |||||
| dataset_document=dataset_document, | |||||
| documents=documents | |||||
| ) | |||||
| except DocumentIsPausedException: | |||||
| raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id)) | |||||
| except ProviderTokenNotInitError as e: | |||||
| dataset_document.indexing_status = 'error' | |||||
| dataset_document.error = str(e.description) | |||||
| dataset_document.stopped_at = datetime.datetime.utcnow() | |||||
| db.session.commit() | |||||
| except Exception as e: | |||||
| logging.exception("consume document failed") | |||||
| dataset_document.indexing_status = 'error' | |||||
| dataset_document.error = str(e) | |||||
| dataset_document.stopped_at = datetime.datetime.utcnow() | |||||
| db.session.commit() | |||||
| def file_indexing_estimate(self, file_details: List[UploadFile], tmp_processing_rule: dict) -> dict: | def file_indexing_estimate(self, file_details: List[UploadFile], tmp_processing_rule: dict) -> dict: | ||||
| """ | """ | ||||
| total_segments = 0 | total_segments = 0 | ||||
| for file_detail in file_details: | for file_detail in file_details: | ||||
| # load data from file | # load data from file | ||||
| text_docs = self._load_data_from_file(file_detail) | |||||
| text_docs = FileExtractor.load(file_detail) | |||||
| processing_rule = DatasetProcessRule( | processing_rule = DatasetProcessRule( | ||||
| mode=tmp_processing_rule["mode"], | mode=tmp_processing_rule["mode"], | ||||
| rules=json.dumps(tmp_processing_rule["rules"]) | rules=json.dumps(tmp_processing_rule["rules"]) | ||||
| ) | ) | ||||
| # get node parser for splitting | |||||
| node_parser = self._get_node_parser(processing_rule) | |||||
| # get splitter | |||||
| splitter = self._get_splitter(processing_rule) | |||||
| # split to nodes | |||||
| nodes = self._split_to_nodes( | |||||
| # split to documents | |||||
| documents = self._split_to_documents( | |||||
| text_docs=text_docs, | text_docs=text_docs, | ||||
| node_parser=node_parser, | |||||
| splitter=splitter, | |||||
| processing_rule=processing_rule | processing_rule=processing_rule | ||||
| ) | ) | ||||
| total_segments += len(nodes) | |||||
| for node in nodes: | |||||
| total_segments += len(documents) | |||||
| for document in documents: | |||||
| if len(preview_texts) < 5: | if len(preview_texts) < 5: | ||||
| preview_texts.append(node.get_text()) | |||||
| preview_texts.append(document.page_content) | |||||
| tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, node.get_text()) | |||||
| tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content) | |||||
| return { | return { | ||||
| "total_segments": total_segments, | "total_segments": total_segments, | ||||
| ).first() | ).first() | ||||
| if not data_source_binding: | if not data_source_binding: | ||||
| raise ValueError('Data source binding not found.') | raise ValueError('Data source binding not found.') | ||||
| reader = NotionPageReader(integration_token=data_source_binding.access_token) | |||||
| for page in notion_info['pages']: | for page in notion_info['pages']: | ||||
| if page['type'] == 'page': | |||||
| page_ids = [page['page_id']] | |||||
| documents = reader.load_data_as_documents(page_ids=page_ids) | |||||
| elif page['type'] == 'database': | |||||
| documents = reader.load_data_as_documents(database_id=page['page_id']) | |||||
| else: | |||||
| documents = [] | |||||
| loader = NotionLoader( | |||||
| notion_access_token=data_source_binding.access_token, | |||||
| notion_workspace_id=workspace_id, | |||||
| notion_obj_id=page['page_id'], | |||||
| notion_page_type=page['type'] | |||||
| ) | |||||
| documents = loader.load() | |||||
| processing_rule = DatasetProcessRule( | processing_rule = DatasetProcessRule( | ||||
| mode=tmp_processing_rule["mode"], | mode=tmp_processing_rule["mode"], | ||||
| rules=json.dumps(tmp_processing_rule["rules"]) | rules=json.dumps(tmp_processing_rule["rules"]) | ||||
| ) | ) | ||||
| # get node parser for splitting | |||||
| node_parser = self._get_node_parser(processing_rule) | |||||
| # get splitter | |||||
| splitter = self._get_splitter(processing_rule) | |||||
| # split to nodes | |||||
| nodes = self._split_to_nodes( | |||||
| # split to documents | |||||
| documents = self._split_to_documents( | |||||
| text_docs=documents, | text_docs=documents, | ||||
| node_parser=node_parser, | |||||
| splitter=splitter, | |||||
| processing_rule=processing_rule | processing_rule=processing_rule | ||||
| ) | ) | ||||
| total_segments += len(nodes) | |||||
| for node in nodes: | |||||
| total_segments += len(documents) | |||||
| for document in documents: | |||||
| if len(preview_texts) < 5: | if len(preview_texts) < 5: | ||||
| preview_texts.append(node.get_text()) | |||||
| preview_texts.append(document.page_content) | |||||
| tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, node.get_text()) | |||||
| tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content) | |||||
| return { | return { | ||||
| "total_segments": total_segments, | "total_segments": total_segments, | ||||
| "preview": preview_texts | "preview": preview_texts | ||||
| } | } | ||||
| def _load_data(self, document: Document) -> List[Document]: | |||||
| def _load_data(self, dataset_document: DatasetDocument) -> List[Document]: | |||||
| # load file | # load file | ||||
| if document.data_source_type not in ["upload_file", "notion_import"]: | |||||
| if dataset_document.data_source_type not in ["upload_file", "notion_import"]: | |||||
| return [] | return [] | ||||
| data_source_info = document.data_source_info_dict | |||||
| data_source_info = dataset_document.data_source_info_dict | |||||
| text_docs = [] | text_docs = [] | ||||
| if document.data_source_type == 'upload_file': | |||||
| if dataset_document.data_source_type == 'upload_file': | |||||
| if not data_source_info or 'upload_file_id' not in data_source_info: | if not data_source_info or 'upload_file_id' not in data_source_info: | ||||
| raise ValueError("no upload file found") | raise ValueError("no upload file found") | ||||
| filter(UploadFile.id == data_source_info['upload_file_id']). \ | filter(UploadFile.id == data_source_info['upload_file_id']). \ | ||||
| one_or_none() | one_or_none() | ||||
| text_docs = self._load_data_from_file(file_detail) | |||||
| elif document.data_source_type == 'notion_import': | |||||
| if not data_source_info or 'notion_page_id' not in data_source_info \ | |||||
| or 'notion_workspace_id' not in data_source_info: | |||||
| raise ValueError("no notion page found") | |||||
| workspace_id = data_source_info['notion_workspace_id'] | |||||
| page_id = data_source_info['notion_page_id'] | |||||
| page_type = data_source_info['type'] | |||||
| data_source_binding = DataSourceBinding.query.filter( | |||||
| db.and_( | |||||
| DataSourceBinding.tenant_id == document.tenant_id, | |||||
| DataSourceBinding.provider == 'notion', | |||||
| DataSourceBinding.disabled == False, | |||||
| DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"' | |||||
| ) | |||||
| ).first() | |||||
| if not data_source_binding: | |||||
| raise ValueError('Data source binding not found.') | |||||
| if page_type == 'page': | |||||
| # add page last_edited_time to data_source_info | |||||
| self._get_notion_page_last_edited_time(page_id, data_source_binding.access_token, document) | |||||
| text_docs = self._load_page_data_from_notion(page_id, data_source_binding.access_token) | |||||
| elif page_type == 'database': | |||||
| # add page last_edited_time to data_source_info | |||||
| self._get_notion_database_last_edited_time(page_id, data_source_binding.access_token, document) | |||||
| text_docs = self._load_database_data_from_notion(page_id, data_source_binding.access_token) | |||||
| text_docs = FileExtractor.load(file_detail) | |||||
| elif dataset_document.data_source_type == 'notion_import': | |||||
| loader = NotionLoader.from_document(dataset_document) | |||||
| text_docs = loader.load() | |||||
| # update document status to splitting | # update document status to splitting | ||||
| self._update_document_index_status( | self._update_document_index_status( | ||||
| document_id=document.id, | |||||
| document_id=dataset_document.id, | |||||
| after_indexing_status="splitting", | after_indexing_status="splitting", | ||||
| extra_update_params={ | extra_update_params={ | ||||
| Document.word_count: sum([len(text_doc.text) for text_doc in text_docs]), | |||||
| Document.parsing_completed_at: datetime.datetime.utcnow() | |||||
| DatasetDocument.word_count: sum([len(text_doc.page_content) for text_doc in text_docs]), | |||||
| DatasetDocument.parsing_completed_at: datetime.datetime.utcnow() | |||||
| } | } | ||||
| ) | ) | ||||
| # replace doc id to document model id | # replace doc id to document model id | ||||
| text_docs = cast(List[Document], text_docs) | |||||
| for text_doc in text_docs: | for text_doc in text_docs: | ||||
| # remove invalid symbol | # remove invalid symbol | ||||
| text_doc.text = self.filter_string(text_doc.get_text()) | |||||
| text_doc.doc_id = document.id | |||||
| text_doc.page_content = self.filter_string(text_doc.page_content) | |||||
| text_doc.metadata['document_id'] = dataset_document.id | |||||
| text_doc.metadata['dataset_id'] = dataset_document.dataset_id | |||||
| return text_docs | return text_docs | ||||
| pattern = re.compile('[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\x80-\xFF]') | pattern = re.compile('[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\x80-\xFF]') | ||||
| return pattern.sub('', text) | return pattern.sub('', text) | ||||
| def _load_data_from_file(self, upload_file: UploadFile) -> List[Document]: | |||||
| with tempfile.TemporaryDirectory() as temp_dir: | |||||
| suffix = Path(upload_file.key).suffix | |||||
| filepath = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" | |||||
| self.storage.download(upload_file.key, filepath) | |||||
| file_extractor = DEFAULT_FILE_EXTRACTOR.copy() | |||||
| file_extractor[".markdown"] = MarkdownParser() | |||||
| file_extractor[".md"] = MarkdownParser() | |||||
| file_extractor[".html"] = HTMLParser() | |||||
| file_extractor[".htm"] = HTMLParser() | |||||
| file_extractor[".pdf"] = PDFParser({'upload_file': upload_file}) | |||||
| file_extractor[".xlsx"] = XLSXParser() | |||||
| loader = SimpleDirectoryReader(input_files=[filepath], file_extractor=file_extractor) | |||||
| text_docs = loader.load_data() | |||||
| return text_docs | |||||
| def _load_page_data_from_notion(self, page_id: str, access_token: str) -> List[Document]: | |||||
| page_ids = [page_id] | |||||
| reader = NotionPageReader(integration_token=access_token) | |||||
| text_docs = reader.load_data_as_documents(page_ids=page_ids) | |||||
| return text_docs | |||||
| def _load_database_data_from_notion(self, database_id: str, access_token: str) -> List[Document]: | |||||
| reader = NotionPageReader(integration_token=access_token) | |||||
| text_docs = reader.load_data_as_documents(database_id=database_id) | |||||
| return text_docs | |||||
| def _get_notion_page_last_edited_time(self, page_id: str, access_token: str, document: Document): | |||||
| reader = NotionPageReader(integration_token=access_token) | |||||
| last_edited_time = reader.get_page_last_edited_time(page_id) | |||||
| data_source_info = document.data_source_info_dict | |||||
| data_source_info['last_edited_time'] = last_edited_time | |||||
| update_params = { | |||||
| Document.data_source_info: json.dumps(data_source_info) | |||||
| } | |||||
| Document.query.filter_by(id=document.id).update(update_params) | |||||
| db.session.commit() | |||||
| def _get_notion_database_last_edited_time(self, page_id: str, access_token: str, document: Document): | |||||
| reader = NotionPageReader(integration_token=access_token) | |||||
| last_edited_time = reader.get_database_last_edited_time(page_id) | |||||
| data_source_info = document.data_source_info_dict | |||||
| data_source_info['last_edited_time'] = last_edited_time | |||||
| update_params = { | |||||
| Document.data_source_info: json.dumps(data_source_info) | |||||
| } | |||||
| Document.query.filter_by(id=document.id).update(update_params) | |||||
| db.session.commit() | |||||
| def _get_node_parser(self, processing_rule: DatasetProcessRule) -> NodeParser: | |||||
| def _get_splitter(self, processing_rule: DatasetProcessRule) -> TextSplitter: | |||||
| """ | """ | ||||
| Get the NodeParser object according to the processing rule. | Get the text splitter according to the processing rule. | ||||
| """ | """ | ||||
| separators=["\n\n", "。", ".", " ", ""] | separators=["\n\n", "。", ".", " ", ""] | ||||
| ) | ) | ||||
| return SimpleNodeParser(text_splitter=character_splitter, include_extra_info=True) | |||||
| return character_splitter | |||||
| def _step_split(self, text_docs: List[Document], node_parser: NodeParser, | |||||
| dataset: Dataset, document: Document, processing_rule: DatasetProcessRule) -> List[Node]: | |||||
| def _step_split(self, text_docs: List[Document], splitter: TextSplitter, | |||||
| dataset: Dataset, dataset_document: DatasetDocument, processing_rule: DatasetProcessRule) \ | |||||
| -> List[Document]: | |||||
| """ | """ | ||||
| Split the text documents into nodes and save them to the document segment. | |||||
| Split the text documents into document chunks and save them as document segments. | |||||
| """ | """ | ||||
| nodes = self._split_to_nodes( | |||||
| documents = self._split_to_documents( | |||||
| text_docs=text_docs, | text_docs=text_docs, | ||||
| node_parser=node_parser, | |||||
| splitter=splitter, | |||||
| processing_rule=processing_rule | processing_rule=processing_rule | ||||
| ) | ) | ||||
| # save node to document segment | # save node to document segment | ||||
| doc_store = DatesetDocumentStore( | doc_store = DatesetDocumentStore( | ||||
| dataset=dataset, | dataset=dataset, | ||||
| user_id=document.created_by, | |||||
| user_id=dataset_document.created_by, | |||||
| embedding_model_name=self.embedding_model_name, | embedding_model_name=self.embedding_model_name, | ||||
| document_id=document.id | |||||
| document_id=dataset_document.id | |||||
| ) | ) | ||||
| # add document segments | # add document segments | ||||
| doc_store.add_documents(nodes) | |||||
| doc_store.add_documents(documents) | |||||
| # update document status to indexing | # update document status to indexing | ||||
| cur_time = datetime.datetime.utcnow() | cur_time = datetime.datetime.utcnow() | ||||
| self._update_document_index_status( | self._update_document_index_status( | ||||
| document_id=document.id, | |||||
| document_id=dataset_document.id, | |||||
| after_indexing_status="indexing", | after_indexing_status="indexing", | ||||
| extra_update_params={ | extra_update_params={ | ||||
| Document.cleaning_completed_at: cur_time, | |||||
| Document.splitting_completed_at: cur_time, | |||||
| DatasetDocument.cleaning_completed_at: cur_time, | |||||
| DatasetDocument.splitting_completed_at: cur_time, | |||||
| } | } | ||||
| ) | ) | ||||
| # update segment status to indexing | # update segment status to indexing | ||||
| self._update_segments_by_document( | self._update_segments_by_document( | ||||
| document_id=document.id, | |||||
| dataset_document_id=dataset_document.id, | |||||
| update_params={ | update_params={ | ||||
| DocumentSegment.status: "indexing", | DocumentSegment.status: "indexing", | ||||
| DocumentSegment.indexing_at: datetime.datetime.utcnow() | DocumentSegment.indexing_at: datetime.datetime.utcnow() | ||||
| } | } | ||||
| ) | ) | ||||
| return nodes | |||||
| return documents | |||||
| def _split_to_nodes(self, text_docs: List[Document], node_parser: NodeParser, | |||||
| processing_rule: DatasetProcessRule) -> List[Node]: | |||||
| def _split_to_documents(self, text_docs: List[Document], splitter: TextSplitter, | |||||
| processing_rule: DatasetProcessRule) -> List[Document]: | |||||
| """ | """ | ||||
| Split the text documents into nodes. | Split the text documents into document chunks. | ||||
| """ | """ | ||||
| all_nodes = [] | |||||
| all_documents = [] | |||||
| for text_doc in text_docs: | for text_doc in text_docs: | ||||
| # document clean | # document clean | ||||
| document_text = self._document_clean(text_doc.get_text(), processing_rule) | |||||
| text_doc.text = document_text | |||||
| document_text = self._document_clean(text_doc.page_content, processing_rule) | |||||
| text_doc.page_content = document_text | |||||
| # parse document to nodes | # parse document to nodes | ||||
| nodes = node_parser.get_nodes_from_documents([text_doc]) | |||||
| nodes = [node for node in nodes if node.text is not None and node.text.strip()] | |||||
| all_nodes.extend(nodes) | |||||
| documents = splitter.split_documents([text_doc]) | |||||
| split_documents = [] | |||||
| for document in documents: | |||||
| if document.page_content is None or not document.page_content.strip(): | |||||
| continue | |||||
| doc_id = str(uuid.uuid4()) | |||||
| hash = helper.generate_text_hash(document.page_content) | |||||
| document.metadata['doc_id'] = doc_id | |||||
| document.metadata['doc_hash'] = hash | |||||
| split_documents.append(document) | |||||
| all_documents.extend(split_documents) | |||||
| return all_nodes | |||||
| return all_documents | |||||
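A small usage sketch of the splitting step (sample text and sizes are hypothetical): every chunk produced by the splitter is tagged with a fresh doc_id and a content hash before it is stored as a segment or sent to an index.

    # Hedged sketch: splitting one loaded document into indexable chunks
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=500, chunk_overlap=50,
        separators=["\n\n", "。", ".", " ", ""]
    )
    chunks = splitter.split_documents([Document(page_content="a long article body", metadata={})])
    for chunk in chunks:
        chunk.metadata['doc_id'] = str(uuid.uuid4())
        chunk.metadata['doc_hash'] = helper.generate_text_hash(chunk.page_content)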
| def _document_clean(self, text: str, processing_rule: DatasetProcessRule) -> str: | def _document_clean(self, text: str, processing_rule: DatasetProcessRule) -> str: | ||||
| """ | """ | ||||
| return text | return text | ||||
| def _build_index(self, dataset: Dataset, document: Document, nodes: List[Node]) -> None: | |||||
| def _build_index(self, dataset: Dataset, dataset_document: DatasetDocument, documents: List[Document]) -> None: | |||||
| """ | """ | ||||
| Build the index for the document. | Build the index for the document. | ||||
| """ | """ | ||||
| vector_index = VectorIndex(dataset=dataset) | |||||
| keyword_table_index = KeywordTableIndex(dataset=dataset) | |||||
| vector_index = IndexBuilder.get_index(dataset, 'high_quality') | |||||
| keyword_table_index = IndexBuilder.get_index(dataset, 'economy') | |||||
| # chunk nodes by chunk size | # chunk nodes by chunk size | ||||
| indexing_start_at = time.perf_counter() | indexing_start_at = time.perf_counter() | ||||
| tokens = 0 | tokens = 0 | ||||
| chunk_size = 100 | chunk_size = 100 | ||||
| for i in range(0, len(nodes), chunk_size): | |||||
| for i in range(0, len(documents), chunk_size): | |||||
| # check document is paused | # check document is paused | ||||
| self._check_document_paused_status(document.id) | |||||
| chunk_nodes = nodes[i:i + chunk_size] | |||||
| self._check_document_paused_status(dataset_document.id) | |||||
| chunk_documents = documents[i:i + chunk_size] | |||||
| tokens += sum( | tokens += sum( | ||||
| TokenCalculator.get_num_tokens(self.embedding_model_name, node.get_text()) for node in chunk_nodes | |||||
| TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content) | |||||
| for document in chunk_documents | |||||
| ) | ) | ||||
| # save vector index | # save vector index | ||||
| if dataset.indexing_technique == "high_quality": | |||||
| vector_index.add_nodes(chunk_nodes) | |||||
| if vector_index: | |||||
| vector_index.add_texts(chunk_documents) | |||||
| # save keyword index | # save keyword index | ||||
| keyword_table_index.add_nodes(chunk_nodes) | |||||
| keyword_table_index.add_texts(chunk_documents) | |||||
| node_ids = [node.doc_id for node in chunk_nodes] | |||||
| document_ids = [document.metadata['doc_id'] for document in chunk_documents] | |||||
| db.session.query(DocumentSegment).filter( | db.session.query(DocumentSegment).filter( | ||||
| DocumentSegment.document_id == document.id, | |||||
| DocumentSegment.index_node_id.in_(node_ids), | |||||
| DocumentSegment.document_id == dataset_document.id, | |||||
| DocumentSegment.index_node_id.in_(document_ids), | |||||
| DocumentSegment.status == "indexing" | DocumentSegment.status == "indexing" | ||||
| ).update({ | ).update({ | ||||
| DocumentSegment.status: "completed", | DocumentSegment.status: "completed", | ||||
| # update document status to completed | # update document status to completed | ||||
| self._update_document_index_status( | self._update_document_index_status( | ||||
| document_id=document.id, | |||||
| document_id=dataset_document.id, | |||||
| after_indexing_status="completed", | after_indexing_status="completed", | ||||
| extra_update_params={ | extra_update_params={ | ||||
| Document.tokens: tokens, | |||||
| Document.completed_at: datetime.datetime.utcnow(), | |||||
| Document.indexing_latency: indexing_end_at - indexing_start_at, | |||||
| DatasetDocument.tokens: tokens, | |||||
| DatasetDocument.completed_at: datetime.datetime.utcnow(), | |||||
| DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at, | |||||
| } | } | ||||
| ) | ) | ||||
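The indexing loop above processes documents in fixed chunks of 100 so that pause checks, token accounting, and vector-store writes happen incrementally rather than once at the end. The underlying batching pattern, for reference:

    # Generic fixed-size batching pattern used by _build_index (chunk size as above)
    chunk_size = 100
    for i in range(0, len(documents), chunk_size):
        chunk_documents = documents[i:i + chunk_size]
        # check the pause flag, count tokens, then write chunk_documents to both indexes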
| """ | """ | ||||
| Update the document indexing status. | Update the document indexing status. | ||||
| """ | """ | ||||
| count = Document.query.filter_by(id=document_id, is_paused=True).count() | |||||
| count = DatasetDocument.query.filter_by(id=document_id, is_paused=True).count() | |||||
| if count > 0: | if count > 0: | ||||
| raise DocumentIsPausedException() | raise DocumentIsPausedException() | ||||
| update_params = { | update_params = { | ||||
| Document.indexing_status: after_indexing_status | |||||
| DatasetDocument.indexing_status: after_indexing_status | |||||
| } | } | ||||
| if extra_update_params: | if extra_update_params: | ||||
| update_params.update(extra_update_params) | update_params.update(extra_update_params) | ||||
| Document.query.filter_by(id=document_id).update(update_params) | |||||
| DatasetDocument.query.filter_by(id=document_id).update(update_params) | |||||
| db.session.commit() | db.session.commit() | ||||
| def _update_segments_by_document(self, document_id: str, update_params: dict) -> None: | |||||
| def _update_segments_by_document(self, dataset_document_id: str, update_params: dict) -> None: | |||||
| """ | """ | ||||
| Update the document segment by document id. | Update the document segment by document id. | ||||
| """ | """ | ||||
| DocumentSegment.query.filter_by(document_id=document_id).update(update_params) | |||||
| DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params) | |||||
| db.session.commit() | db.session.commit() | ||||
| from typing import Union, Optional | |||||
| from typing import Union, Optional, List | |||||
| from langchain.callbacks import CallbackManager | |||||
| from langchain.llms.fake import FakeListLLM | |||||
| from langchain.callbacks.base import BaseCallbackHandler | |||||
| from core.constant import llm_constant | from core.constant import llm_constant | ||||
| from core.llm.error import ProviderTokenNotInitError | from core.llm.error import ProviderTokenNotInitError | ||||
| """ | """ | ||||
| @classmethod | @classmethod | ||||
| def to_llm(cls, tenant_id: str, model_name: str, **kwargs) -> Union[StreamableOpenAI, StreamableChatOpenAI, FakeListLLM]: | |||||
| if model_name == 'fake': | |||||
| return FakeListLLM(responses=[]) | |||||
| def to_llm(cls, tenant_id: str, model_name: str, **kwargs) -> Union[StreamableOpenAI, StreamableChatOpenAI]: | |||||
| provider = cls.get_default_provider(tenant_id) | provider = cls.get_default_provider(tenant_id) | ||||
| model_credentials = cls.get_model_credentials(tenant_id, provider, model_name) | |||||
| mode = cls.get_mode_by_model(model_name) | mode = cls.get_mode_by_model(model_name) | ||||
| if mode == 'chat': | if mode == 'chat': | ||||
| if provider == 'openai': | if provider == 'openai': | ||||
| else: | else: | ||||
| raise ValueError(f"model name {model_name} is not supported.") | raise ValueError(f"model name {model_name} is not supported.") | ||||
| model_credentials = cls.get_model_credentials(tenant_id, provider, model_name) | |||||
| model_kwargs = { | |||||
| 'top_p': kwargs.get('top_p', 1), | |||||
| 'frequency_penalty': kwargs.get('frequency_penalty', 0), | |||||
| 'presence_penalty': kwargs.get('presence_penalty', 0), | |||||
| } | |||||
| model_extras_kwargs = model_kwargs if mode == 'completion' else {'model_kwargs': model_kwargs} | |||||
| return llm_cls( | return llm_cls( | ||||
| model_name=model_name, | model_name=model_name, | ||||
| temperature=kwargs.get('temperature', 0), | temperature=kwargs.get('temperature', 0), | ||||
| max_tokens=kwargs.get('max_tokens', 256), | max_tokens=kwargs.get('max_tokens', 256), | ||||
| top_p=kwargs.get('top_p', 1), | |||||
| frequency_penalty=kwargs.get('frequency_penalty', 0), | |||||
| presence_penalty=kwargs.get('presence_penalty', 0), | |||||
| callback_manager=kwargs.get('callback_manager', None), | |||||
| **model_extras_kwargs, | |||||
| callbacks=kwargs.get('callbacks', None), | |||||
| streaming=kwargs.get('streaming', False), | streaming=kwargs.get('streaming', False), | ||||
| # request_timeout=None | # request_timeout=None | ||||
| **model_credentials | **model_credentials | ||||
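One detail of the rebuilt builder worth noting: sampling parameters are passed straight to completion models, while chat models expect them wrapped in model_kwargs, which is what the model_extras_kwargs expression arranges. A hedged sketch of the two resulting call shapes (model names, values, and model_credentials are illustrative):

    # Completion model: sampling params are top-level keyword arguments
    StreamableOpenAI(model_name='text-davinci-003', temperature=0, max_tokens=256,
                     top_p=1, frequency_penalty=0, presence_penalty=0, **model_credentials)

    # Chat model: the same params travel inside model_kwargs
    StreamableChatOpenAI(model_name='gpt-3.5-turbo', temperature=0, max_tokens=256,
                         model_kwargs={'top_p': 1, 'frequency_penalty': 0, 'presence_penalty': 0},
                         **model_credentials)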
| @classmethod | @classmethod | ||||
| def to_llm_from_model(cls, tenant_id: str, model: dict, streaming: bool = False, | def to_llm_from_model(cls, tenant_id: str, model: dict, streaming: bool = False, | ||||
| callback_manager: Optional[CallbackManager] = None) -> Union[StreamableOpenAI, StreamableChatOpenAI]: | |||||
| callbacks: Optional[List[BaseCallbackHandler]] = None) -> Union[StreamableOpenAI, StreamableChatOpenAI]: | |||||
| model_name = model.get("name") | model_name = model.get("name") | ||||
| completion_params = model.get("completion_params", {}) | completion_params = model.get("completion_params", {}) | ||||
| frequency_penalty=completion_params.get('frequency_penalty', 0.1), | frequency_penalty=completion_params.get('frequency_penalty', 0.1), | ||||
| presence_penalty=completion_params.get('presence_penalty', 0.1), | presence_penalty=completion_params.get('presence_penalty', 0.1), | ||||
| streaming=streaming, | streaming=streaming, | ||||
| callback_manager=callback_manager | |||||
| callbacks=callbacks | |||||
| ) | ) | ||||
| @classmethod | @classmethod |
| """ | """ | ||||
| config = self.get_provider_api_key(model_id=model_id) | config = self.get_provider_api_key(model_id=model_id) | ||||
| config['openai_api_type'] = 'azure' | config['openai_api_type'] = 'azure' | ||||
| config['deployment_name'] = model_id.replace('.', '') if model_id else None | |||||
| if model_id == 'text-embedding-ada-002': | |||||
| config['deployment'] = model_id.replace('.', '') if model_id else None | |||||
| else: | |||||
| config['deployment_name'] = model_id.replace('.', '') if model_id else None | |||||
| return config | return config | ||||
| def get_provider_name(self): | def get_provider_name(self): |
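The new branch above special-cases the embedding model, presumably because the embedding wrapper takes the Azure deployment as 'deployment' while the LLM wrappers take 'deployment_name' (dots are also stripped from the deployment name, e.g. gpt-3.5-turbo becomes gpt-35-turbo). A hedged illustration with a hypothetical helper:

    # Hypothetical helper mirroring the key choice above (model ids are examples)
    def azure_deployment_key(model_id: str) -> str:
        return 'deployment' if model_id == 'text-embedding-ada-002' else 'deployment_name'

    assert azure_deployment_key('text-embedding-ada-002') == 'deployment'
    assert azure_deployment_key('gpt-3.5-turbo') == 'deployment_name'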
| from langchain.callbacks.manager import CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun, Callbacks | |||||
| from langchain.schema import BaseMessage, ChatResult, LLMResult | from langchain.schema import BaseMessage, ChatResult, LLMResult | ||||
| from langchain.chat_models import AzureChatOpenAI | from langchain.chat_models import AzureChatOpenAI | ||||
| from typing import Optional, List, Dict, Any | from typing import Optional, List, Dict, Any | ||||
| return message_tokens | return message_tokens | ||||
| def _generate( | |||||
| self, messages: List[BaseMessage], stop: Optional[List[str]] = None | |||||
| ) -> ChatResult: | |||||
| self.callback_manager.on_llm_start( | |||||
| {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages], | |||||
| verbose=self.verbose | |||||
| ) | |||||
| chat_result = super()._generate(messages, stop) | |||||
| result = LLMResult( | |||||
| generations=[chat_result.generations], | |||||
| llm_output=chat_result.llm_output | |||||
| ) | |||||
| self.callback_manager.on_llm_end(result, verbose=self.verbose) | |||||
| return chat_result | |||||
| async def _agenerate( | |||||
| self, messages: List[BaseMessage], stop: Optional[List[str]] = None | |||||
| ) -> ChatResult: | |||||
| if self.callback_manager.is_async: | |||||
| await self.callback_manager.on_llm_start( | |||||
| {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages], | |||||
| verbose=self.verbose | |||||
| ) | |||||
| else: | |||||
| self.callback_manager.on_llm_start( | |||||
| {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages], | |||||
| verbose=self.verbose | |||||
| ) | |||||
| chat_result = super()._generate(messages, stop) | |||||
| result = LLMResult( | |||||
| generations=[chat_result.generations], | |||||
| llm_output=chat_result.llm_output | |||||
| ) | |||||
| if self.callback_manager.is_async: | |||||
| await self.callback_manager.on_llm_end(result, verbose=self.verbose) | |||||
| else: | |||||
| self.callback_manager.on_llm_end(result, verbose=self.verbose) | |||||
| return chat_result | |||||
| @handle_llm_exceptions | @handle_llm_exceptions | ||||
| def generate( | def generate( | ||||
| self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None | |||||
| self, | |||||
| messages: List[List[BaseMessage]], | |||||
| stop: Optional[List[str]] = None, | |||||
| callbacks: Callbacks = None, | |||||
| **kwargs: Any, | |||||
| ) -> LLMResult: | ) -> LLMResult: | ||||
| return super().generate(messages, stop) | |||||
| return super().generate(messages, stop, callbacks, **kwargs) | |||||
| @handle_llm_exceptions_async | @handle_llm_exceptions_async | ||||
| async def agenerate( | async def agenerate( | ||||
| self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None | |||||
| self, | |||||
| messages: List[List[BaseMessage]], | |||||
| stop: Optional[List[str]] = None, | |||||
| callbacks: Callbacks = None, | |||||
| **kwargs: Any, | |||||
| ) -> LLMResult: | ) -> LLMResult: | ||||
| return await super().agenerate(messages, stop) | |||||
| return await super().agenerate(messages, stop, callbacks, **kwargs) |
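With the CallbackManager plumbing replaced by plain callbacks lists, callers now hand handlers directly to generate/agenerate and LangChain invokes them around the run. A minimal usage sketch (the credentials and prompt are hypothetical; StdOutCallbackHandler is simply a convenient built-in handler):

    from langchain.callbacks import StdOutCallbackHandler
    from langchain.schema import HumanMessage

    llm = StreamableAzureChatOpenAI(deployment_name='gpt-35-turbo', **model_credentials)  # credentials assumed
    result = llm.generate(
        [[HumanMessage(content='Hello!')]],
        callbacks=[StdOutCallbackHandler()],   # any BaseCallbackHandler implementations
    )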
| import os | |||||
| from langchain.callbacks.manager import Callbacks | |||||
| from langchain.llms import AzureOpenAI | from langchain.llms import AzureOpenAI | ||||
| from langchain.schema import LLMResult | from langchain.schema import LLMResult | ||||
| from typing import Optional, List, Dict, Mapping, Any | from typing import Optional, List, Dict, Mapping, Any | ||||
| @handle_llm_exceptions | @handle_llm_exceptions | ||||
| def generate( | def generate( | ||||
| self, prompts: List[str], stop: Optional[List[str]] = None | |||||
| self, | |||||
| prompts: List[str], | |||||
| stop: Optional[List[str]] = None, | |||||
| callbacks: Callbacks = None, | |||||
| **kwargs: Any, | |||||
| ) -> LLMResult: | ) -> LLMResult: | ||||
| return super().generate(prompts, stop) | |||||
| return super().generate(prompts, stop, callbacks, **kwargs) | |||||
| @handle_llm_exceptions_async | @handle_llm_exceptions_async | ||||
| async def agenerate( | async def agenerate( | ||||
| self, prompts: List[str], stop: Optional[List[str]] = None | |||||
| self, | |||||
| prompts: List[str], | |||||
| stop: Optional[List[str]] = None, | |||||
| callbacks: Callbacks = None, | |||||
| **kwargs: Any, | |||||
| ) -> LLMResult: | ) -> LLMResult: | ||||
| return await super().agenerate(prompts, stop) | |||||
| return await super().agenerate(prompts, stop, callbacks, **kwargs) |
| import os | import os | ||||
| from langchain.schema import BaseMessage, ChatResult, LLMResult | |||||
| from langchain.callbacks.manager import Callbacks | |||||
| from langchain.schema import BaseMessage, LLMResult | |||||
| from langchain.chat_models import ChatOpenAI | from langchain.chat_models import ChatOpenAI | ||||
| from typing import Optional, List, Dict, Any | from typing import Optional, List, Dict, Any | ||||
| return message_tokens | return message_tokens | ||||
| def _generate( | |||||
| self, messages: List[BaseMessage], stop: Optional[List[str]] = None | |||||
| ) -> ChatResult: | |||||
| self.callback_manager.on_llm_start( | |||||
| {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages], verbose=self.verbose | |||||
| ) | |||||
| chat_result = super()._generate(messages, stop) | |||||
| result = LLMResult( | |||||
| generations=[chat_result.generations], | |||||
| llm_output=chat_result.llm_output | |||||
| ) | |||||
| self.callback_manager.on_llm_end(result, verbose=self.verbose) | |||||
| return chat_result | |||||
| async def _agenerate( | |||||
| self, messages: List[BaseMessage], stop: Optional[List[str]] = None | |||||
| ) -> ChatResult: | |||||
| if self.callback_manager.is_async: | |||||
| await self.callback_manager.on_llm_start( | |||||
| {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages], verbose=self.verbose | |||||
| ) | |||||
| else: | |||||
| self.callback_manager.on_llm_start( | |||||
| {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages], verbose=self.verbose | |||||
| ) | |||||
| chat_result = super()._generate(messages, stop) | |||||
| result = LLMResult( | |||||
| generations=[chat_result.generations], | |||||
| llm_output=chat_result.llm_output | |||||
| ) | |||||
| if self.callback_manager.is_async: | |||||
| await self.callback_manager.on_llm_end(result, verbose=self.verbose) | |||||
| else: | |||||
| self.callback_manager.on_llm_end(result, verbose=self.verbose) | |||||
| return chat_result | |||||
| @handle_llm_exceptions | @handle_llm_exceptions | ||||
| def generate( | def generate( | ||||
| self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None | |||||
| self, | |||||
| messages: List[List[BaseMessage]], | |||||
| stop: Optional[List[str]] = None, | |||||
| callbacks: Callbacks = None, | |||||
| **kwargs: Any, | |||||
| ) -> LLMResult: | ) -> LLMResult: | ||||
| return super().generate(messages, stop) | |||||
| return super().generate(messages, stop, callbacks, **kwargs) | |||||
| @handle_llm_exceptions_async | @handle_llm_exceptions_async | ||||
| async def agenerate( | async def agenerate( | ||||
| self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None | |||||
| self, | |||||
| messages: List[List[BaseMessage]], | |||||
| stop: Optional[List[str]] = None, | |||||
| callbacks: Callbacks = None, | |||||
| **kwargs: Any, | |||||
| ) -> LLMResult: | ) -> LLMResult: | ||||
| return await super().agenerate(messages, stop) | |||||
| return await super().agenerate(messages, stop, callbacks, **kwargs) |
| import os | import os | ||||
| from langchain.callbacks.manager import Callbacks | |||||
| from langchain.schema import LLMResult | from langchain.schema import LLMResult | ||||
| from typing import Optional, List, Dict, Any, Mapping | from typing import Optional, List, Dict, Any, Mapping | ||||
| from langchain import OpenAI | from langchain import OpenAI | ||||
| "organization": self.openai_organization if self.openai_organization else None, | "organization": self.openai_organization if self.openai_organization else None, | ||||
| }} | }} | ||||
| @handle_llm_exceptions | @handle_llm_exceptions | ||||
| def generate( | def generate( | ||||
| self, prompts: List[str], stop: Optional[List[str]] = None | |||||
| self, | |||||
| prompts: List[str], | |||||
| stop: Optional[List[str]] = None, | |||||
| callbacks: Callbacks = None, | |||||
| **kwargs: Any, | |||||
| ) -> LLMResult: | ) -> LLMResult: | ||||
| return super().generate(prompts, stop) | |||||
| return super().generate(prompts, stop, callbacks, **kwargs) | |||||
| @handle_llm_exceptions_async | @handle_llm_exceptions_async | ||||
| async def agenerate( | async def agenerate( | ||||
| self, prompts: List[str], stop: Optional[List[str]] = None | |||||
| self, | |||||
| prompts: List[str], | |||||
| stop: Optional[List[str]] = None, | |||||
| callbacks: Callbacks = None, | |||||
| **kwargs: Any, | |||||
| ) -> LLMResult: | ) -> LLMResult: | ||||
| return await super().agenerate(prompts, stop) | |||||
| return await super().agenerate(prompts, stop, callbacks, **kwargs) |
| from typing import Any, List, Dict | from typing import Any, List, Dict | ||||
| from langchain.memory.chat_memory import BaseChatMemory | from langchain.memory.chat_memory import BaseChatMemory | ||||
| from langchain.schema import get_buffer_string, BaseMessage, BaseLanguageModel | |||||
| from langchain.schema import get_buffer_string | |||||
| from core.memory.read_only_conversation_token_db_buffer_shared_memory import \ | from core.memory.read_only_conversation_token_db_buffer_shared_memory import \ | ||||
| ReadOnlyConversationTokenDBBufferSharedMemory | ReadOnlyConversationTokenDBBufferSharedMemory |
| from llama_index import QueryKeywordExtractPrompt | |||||
| CONVERSATION_TITLE_PROMPT = ( | CONVERSATION_TITLE_PROMPT = ( | ||||
| "Human:{query}\n-----\n" | "Human:{query}\n-----\n" | ||||
| "Help me summarize the intent of what the human said and provide a title, the title should not exceed 20 words.\n" | "Help me summarize the intent of what the human said and provide a title, the title should not exceed 20 words.\n" | ||||
| "[\"question1\",\"question2\",\"question3\"]\n" | "[\"question1\",\"question2\",\"question3\"]\n" | ||||
| ) | ) | ||||
| QUERY_KEYWORD_EXTRACT_TEMPLATE_TMPL = ( | |||||
| "A question is provided below. Given the question, extract up to {max_keywords} " | |||||
| "keywords from the text. Focus on extracting the keywords that we can use " | |||||
| "to best lookup answers to the question. Avoid stopwords." | |||||
| "I am not sure which language the following question is in. " | |||||
| "If the user asked the question in Chinese, please return the keywords in Chinese. " | |||||
| "If the user asked the question in English, please return the keywords in English.\n" | |||||
| "---------------------\n" | |||||
| "{question}\n" | |||||
| "---------------------\n" | |||||
| "Provide keywords in the following comma-separated format: 'KEYWORDS: <keywords>'\n" | |||||
| ) | |||||
| QUERY_KEYWORD_EXTRACT_TEMPLATE = QueryKeywordExtractPrompt( | |||||
| QUERY_KEYWORD_EXTRACT_TEMPLATE_TMPL | |||||
| ) | |||||
| RULE_CONFIG_GENERATE_TEMPLATE = """Given MY INTENDED AUDIENCES and HOPING TO SOLVE using a language model, please select \ | RULE_CONFIG_GENERATE_TEMPLATE = """Given MY INTENDED AUDIENCES and HOPING TO SOLVE using a language model, please select \ | ||||
| the model prompt that best suits the input. | the model prompt that best suits the input. | ||||
| You will be provided with the prompt, variables, and an opening statement. | You will be provided with the prompt, variables, and an opening statement. |
| from flask import current_app | |||||
| from langchain.embeddings import OpenAIEmbeddings | |||||
| from langchain.tools import BaseTool | |||||
| from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler | |||||
| from core.embedding.cached_embedding import CacheEmbedding | |||||
| from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig | |||||
| from core.index.vector_index.vector_index import VectorIndex | |||||
| from core.llm.llm_builder import LLMBuilder | |||||
| from models.dataset import Dataset | |||||
| class DatasetTool(BaseTool): | |||||
| """Tool for querying a Dataset.""" | |||||
| dataset: Dataset | |||||
| k: int = 2 | |||||
| def _run(self, tool_input: str) -> str: | |||||
| if self.dataset.indexing_technique == "economy": | |||||
| # use keyword table query | |||||
| kw_table_index = KeywordTableIndex( | |||||
| dataset=self.dataset, | |||||
| config=KeywordTableConfig( | |||||
| max_keywords_per_chunk=5 | |||||
| ) | |||||
| ) | |||||
| documents = kw_table_index.search(tool_input, search_kwargs={'k': self.k}) | |||||
| else: | |||||
| model_credentials = LLMBuilder.get_model_credentials( | |||||
| tenant_id=self.dataset.tenant_id, | |||||
| model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id), | |||||
| model_name='text-embedding-ada-002' | |||||
| ) | |||||
| embeddings = CacheEmbedding(OpenAIEmbeddings( | |||||
| **model_credentials | |||||
| )) | |||||
| vector_index = VectorIndex( | |||||
| dataset=self.dataset, | |||||
| config=current_app.config, | |||||
| embeddings=embeddings | |||||
| ) | |||||
| documents = vector_index.search( | |||||
| tool_input, | |||||
| search_type='similarity', | |||||
| search_kwargs={ | |||||
| 'k': self.k | |||||
| } | |||||
| ) | |||||
| hit_callback = DatasetIndexToolCallbackHandler(self.dataset.id) | |||||
| hit_callback.on_tool_end(documents) | |||||
| return str("\n".join([document.page_content for document in documents])) | |||||
| async def _arun(self, tool_input: str) -> str: | |||||
| model_credentials = LLMBuilder.get_model_credentials( | |||||
| tenant_id=self.dataset.tenant_id, | |||||
| model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id), | |||||
| model_name='text-embedding-ada-002' | |||||
| ) | |||||
| embeddings = CacheEmbedding(OpenAIEmbeddings( | |||||
| **model_credentials | |||||
| )) | |||||
| vector_index = VectorIndex( | |||||
| dataset=self.dataset, | |||||
| config=current_app.config, | |||||
| embeddings=embeddings | |||||
| ) | |||||
| documents = await vector_index.asearch( | |||||
| tool_input, | |||||
| search_type='similarity', | |||||
| search_kwargs={ | |||||
| 'k': 10 | |||||
| } | |||||
| ) | |||||
| hit_callback = DatasetIndexToolCallbackHandler(self.dataset.id) | |||||
| hit_callback.on_tool_end(documents) | |||||
| return str("\n".join([document.page_content for document in documents])) |
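A minimal usage sketch for the new DatasetTool, assuming an active Flask application context and a dataset loaded via the ORM; the tool name, description, and the import path for DatasetTool are illustrative assumptions:

    from extensions.ext_database import db
    from models.dataset import Dataset
    from core.tool.dataset_tool import DatasetTool  # assumed module path for the class above

    dataset = db.session.query(Dataset).filter(Dataset.id == 'your-dataset-id').first()
    tool = DatasetTool(
        name=f"dataset-{dataset.id}",
        description=dataset.description or f"useful for when you want to answer queries about {dataset.name}",
        dataset=dataset,
        k=2,  # number of segments to return
    )
    # BaseTool.run dispatches to _run, which picks keyword or vector search by indexing_technique
    print(tool.run("What does this dataset cover?"))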
| from typing import Optional | |||||
| from langchain.callbacks import CallbackManager | |||||
| from llama_index.langchain_helpers.agents import IndexToolConfig | |||||
| from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler | |||||
| from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler | |||||
| from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler | |||||
| from core.index.keyword_table_index import KeywordTableIndex | |||||
| from core.index.vector_index import VectorIndex | |||||
| from core.prompt.prompts import QUERY_KEYWORD_EXTRACT_TEMPLATE | |||||
| from core.tool.llama_index_tool import EnhanceLlamaIndexTool | |||||
| from models.dataset import Dataset | |||||
| class DatasetToolBuilder: | |||||
| @classmethod | |||||
| def build_dataset_tool(cls, dataset: Dataset, | |||||
| response_mode: str = "no_synthesizer", | |||||
| callback_handler: Optional[DatasetToolCallbackHandler] = None): | |||||
| if dataset.indexing_technique == "economy": | |||||
| # use keyword table query | |||||
| index = KeywordTableIndex(dataset=dataset).query_index | |||||
| if not index: | |||||
| return None | |||||
| query_kwargs = { | |||||
| "mode": "default", | |||||
| "response_mode": response_mode, | |||||
| "query_keyword_extract_template": QUERY_KEYWORD_EXTRACT_TEMPLATE, | |||||
| "max_keywords_per_query": 5, | |||||
| # If num_chunks_per_query is too large, | |||||
| # it will slow down the synthesis process due to multiple iterations of refinement. | |||||
| "num_chunks_per_query": 2 | |||||
| } | |||||
| else: | |||||
| index = VectorIndex(dataset=dataset).query_index | |||||
| if not index: | |||||
| return None | |||||
| query_kwargs = { | |||||
| "mode": "default", | |||||
| "response_mode": response_mode, | |||||
| # If top_k is too large, | |||||
| # it will slow down the synthesis process due to multiple iterations of refinement. | |||||
| "similarity_top_k": 2 | |||||
| } | |||||
| # fulfill description when it is empty | |||||
| description = dataset.description | |||||
| if not description: | |||||
| description = 'useful for when you want to answer queries about the ' + dataset.name | |||||
| index_tool_config = IndexToolConfig( | |||||
| index=index, | |||||
| name=f"dataset-{dataset.id}", | |||||
| description=description, | |||||
| index_query_kwargs=query_kwargs, | |||||
| tool_kwargs={ | |||||
| "callback_manager": CallbackManager([callback_handler, DifyStdOutCallbackHandler()]) | |||||
| }, | |||||
| # tool_kwargs={"return_direct": True}, | |||||
| # return_direct: Whether to return LLM results directly or process the output data with an Output Parser | |||||
| ) | |||||
| index_callback_handler = DatasetIndexToolCallbackHandler(dataset_id=dataset.id) | |||||
| return EnhanceLlamaIndexTool.from_tool_config( | |||||
| tool_config=index_tool_config, | |||||
| callback_handler=index_callback_handler | |||||
| ) |
| from typing import Dict | |||||
| from langchain.tools import BaseTool | |||||
| from llama_index.indices.base import BaseGPTIndex | |||||
| from llama_index.langchain_helpers.agents import IndexToolConfig | |||||
| from pydantic import Field | |||||
| from core.callback_handler.index_tool_callback_handler import IndexToolCallbackHandler | |||||
| class EnhanceLlamaIndexTool(BaseTool): | |||||
| """Tool for querying a LlamaIndex.""" | |||||
| # NOTE: name/description still needs to be set | |||||
| index: BaseGPTIndex | |||||
| query_kwargs: Dict = Field(default_factory=dict) | |||||
| return_sources: bool = False | |||||
| callback_handler: IndexToolCallbackHandler | |||||
| @classmethod | |||||
| def from_tool_config(cls, tool_config: IndexToolConfig, | |||||
| callback_handler: IndexToolCallbackHandler) -> "EnhanceLlamaIndexTool": | |||||
| """Create a tool from a tool config.""" | |||||
| return_sources = tool_config.tool_kwargs.pop("return_sources", False) | |||||
| return cls( | |||||
| index=tool_config.index, | |||||
| callback_handler=callback_handler, | |||||
| name=tool_config.name, | |||||
| description=tool_config.description, | |||||
| return_sources=return_sources, | |||||
| query_kwargs=tool_config.index_query_kwargs, | |||||
| **tool_config.tool_kwargs, | |||||
| ) | |||||
| def _run(self, tool_input: str) -> str: | |||||
| response = self.index.query(tool_input, **self.query_kwargs) | |||||
| self.callback_handler.on_tool_end(response) | |||||
| return str(response) | |||||
| async def _arun(self, tool_input: str) -> str: | |||||
| response = await self.index.aquery(tool_input, **self.query_kwargs) | |||||
| self.callback_handler.on_tool_end(response) | |||||
| return str(response) |
| from abc import ABC, abstractmethod | |||||
| from typing import Optional | |||||
| from llama_index import ServiceContext, GPTVectorStoreIndex | |||||
| from llama_index.data_structs import Node | |||||
| from llama_index.vector_stores.types import VectorStore | |||||
| class BaseVectorStoreClient(ABC): | |||||
| @abstractmethod | |||||
| def get_index(self, service_context: ServiceContext, config: dict) -> GPTVectorStoreIndex: | |||||
| raise NotImplementedError | |||||
| @abstractmethod | |||||
| def to_index_config(self, index_id: str) -> dict: | |||||
| raise NotImplementedError | |||||
| class BaseGPTVectorStoreIndex(GPTVectorStoreIndex): | |||||
| def delete_node(self, node_id: str): | |||||
| self._vector_store.delete_node(node_id) | |||||
| def exists_by_node_id(self, node_id: str) -> bool: | |||||
| return self._vector_store.exists_by_node_id(node_id) | |||||
| class EnhanceVectorStore(ABC): | |||||
| @abstractmethod | |||||
| def delete_node(self, node_id: str): | |||||
| pass | |||||
| @abstractmethod | |||||
| def exists_by_node_id(self, node_id: str) -> bool: | |||||
| pass |
| from typing import cast, Any | |||||
| from langchain.schema import Document | |||||
| from langchain.vectorstores import Qdrant | |||||
| from qdrant_client.http.models import Filter, PointIdsList, FilterSelector | |||||
| from qdrant_client.local.qdrant_local import QdrantLocal | |||||
| class QdrantVectorStore(Qdrant): | |||||
| def del_texts(self, filter: Filter): | |||||
| if not filter: | |||||
| raise ValueError('filter must not be empty') | |||||
| self._reload_if_needed() | |||||
| self.client.delete( | |||||
| collection_name=self.collection_name, | |||||
| points_selector=FilterSelector( | |||||
| filter=filter | |||||
| ), | |||||
| ) | |||||
| def del_text(self, uuid: str) -> None: | |||||
| self._reload_if_needed() | |||||
| self.client.delete( | |||||
| collection_name=self.collection_name, | |||||
| points_selector=PointIdsList( | |||||
| points=[uuid], | |||||
| ), | |||||
| ) | |||||
| def text_exists(self, uuid: str) -> bool: | |||||
| self._reload_if_needed() | |||||
| response = self.client.retrieve( | |||||
| collection_name=self.collection_name, | |||||
| ids=[uuid] | |||||
| ) | |||||
| return len(response) > 0 | |||||
| def delete(self): | |||||
| self._reload_if_needed() | |||||
| self.client.delete_collection(collection_name=self.collection_name) | |||||
| @classmethod | |||||
| def _document_from_scored_point( | |||||
| cls, | |||||
| scored_point: Any, | |||||
| content_payload_key: str, | |||||
| metadata_payload_key: str, | |||||
| ) -> Document: | |||||
| if scored_point.payload.get('doc_id'): | |||||
| return Document( | |||||
| page_content=scored_point.payload.get(content_payload_key), | |||||
| metadata={'doc_id': scored_point.id} | |||||
| ) | |||||
| return Document( | |||||
| page_content=scored_point.payload.get(content_payload_key), | |||||
| metadata=scored_point.payload.get(metadata_payload_key) or {}, | |||||
| ) | |||||
| def _reload_if_needed(self): | |||||
| if isinstance(self.client, QdrantLocal): | |||||
| self.client = cast(QdrantLocal, self.client) | |||||
| self.client._load() |
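A minimal sketch of the point-level helpers this subclass adds, assuming a locally running Qdrant instance and an existing collection; the constructor arguments mirror LangChain's Qdrant wrapper of this era and are shown here as assumptions:

    import qdrant_client
    from langchain.embeddings import OpenAIEmbeddings

    client = qdrant_client.QdrantClient(url="http://localhost:6333")  # assumed local instance
    store = QdrantVectorStore(
        client=client,
        collection_name="example_collection",    # assumed existing collection
        embeddings=OpenAIEmbeddings(),            # needed by the base class, unused by the helpers below
    )

    node_uuid = "00000000-0000-0000-0000-000000000001"  # hypothetical point id
    if store.text_exists(node_uuid):   # retrieve by point id
        store.del_text(node_uuid)      # delete that single point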
| import os | |||||
| from typing import cast, List | |||||
| from llama_index.data_structs import Node | |||||
| from llama_index.data_structs.node_v2 import DocumentRelationship | |||||
| from llama_index.vector_stores.types import VectorStoreQuery, VectorStoreQueryResult | |||||
| from qdrant_client.http.models import Payload, Filter | |||||
| import qdrant_client | |||||
| from llama_index import ServiceContext, GPTVectorStoreIndex, GPTQdrantIndex | |||||
| from llama_index.data_structs.data_structs_v2 import QdrantIndexDict | |||||
| from llama_index.vector_stores import QdrantVectorStore | |||||
| from qdrant_client.local.qdrant_local import QdrantLocal | |||||
| from core.vector_store.base import BaseVectorStoreClient, BaseGPTVectorStoreIndex, EnhanceVectorStore | |||||
| class QdrantVectorStoreClient(BaseVectorStoreClient): | |||||
| def __init__(self, url: str, api_key: str, root_path: str): | |||||
| self._client = self.init_from_config(url, api_key, root_path) | |||||
| @classmethod | |||||
| def init_from_config(cls, url: str, api_key: str, root_path: str): | |||||
| if url and url.startswith('path:'): | |||||
| path = url.replace('path:', '') | |||||
| if not os.path.isabs(path): | |||||
| path = os.path.join(root_path, path) | |||||
| return qdrant_client.QdrantClient( | |||||
| path=path | |||||
| ) | |||||
| else: | |||||
| return qdrant_client.QdrantClient( | |||||
| url=url, | |||||
| api_key=api_key, | |||||
| ) | |||||
| def get_index(self, service_context: ServiceContext, config: dict) -> GPTVectorStoreIndex: | |||||
| index_struct = QdrantIndexDict() | |||||
| if self._client is None: | |||||
| raise Exception("Vector client is not initialized.") | |||||
| # {"collection_name": "Gpt_index_xxx"} | |||||
| collection_name = config.get('collection_name') | |||||
| if not collection_name: | |||||
| raise Exception("collection_name cannot be None.") | |||||
| return GPTQdrantEnhanceIndex( | |||||
| service_context=service_context, | |||||
| index_struct=index_struct, | |||||
| vector_store=QdrantEnhanceVectorStore( | |||||
| client=self._client, | |||||
| collection_name=collection_name | |||||
| ) | |||||
| ) | |||||
| def to_index_config(self, index_id: str) -> dict: | |||||
| return {"collection_name": index_id} | |||||
| class GPTQdrantEnhanceIndex(GPTQdrantIndex, BaseGPTVectorStoreIndex): | |||||
| pass | |||||
| class QdrantEnhanceVectorStore(QdrantVectorStore, EnhanceVectorStore): | |||||
| def delete_node(self, node_id: str): | |||||
| """ | |||||
| Delete node from the index. | |||||
| :param node_id: node id | |||||
| """ | |||||
| from qdrant_client.http import models as rest | |||||
| self._reload_if_needed() | |||||
| self._client.delete( | |||||
| collection_name=self._collection_name, | |||||
| points_selector=rest.Filter( | |||||
| must=[ | |||||
| rest.FieldCondition( | |||||
| key="id", match=rest.MatchValue(value=node_id) | |||||
| ) | |||||
| ] | |||||
| ), | |||||
| ) | |||||
| def exists_by_node_id(self, node_id: str) -> bool: | |||||
| """ | |||||
| Get node from the index by node id. | |||||
| :param node_id: node id | |||||
| """ | |||||
| self._reload_if_needed() | |||||
| response = self._client.retrieve( | |||||
| collection_name=self._collection_name, | |||||
| ids=[node_id] | |||||
| ) | |||||
| return len(response) > 0 | |||||
| def query( | |||||
| self, | |||||
| query: VectorStoreQuery, | |||||
| ) -> VectorStoreQueryResult: | |||||
| """Query index for top k most similar nodes. | |||||
| Args: | |||||
| query (VectorStoreQuery): query | |||||
| """ | |||||
| query_embedding = cast(List[float], query.query_embedding) | |||||
| self._reload_if_needed() | |||||
| response = self._client.search( | |||||
| collection_name=self._collection_name, | |||||
| query_vector=query_embedding, | |||||
| limit=cast(int, query.similarity_top_k), | |||||
| query_filter=cast(Filter, self._build_query_filter(query)), | |||||
| with_vectors=True | |||||
| ) | |||||
| nodes = [] | |||||
| similarities = [] | |||||
| ids = [] | |||||
| for point in response: | |||||
| payload = cast(Payload, point.payload) | |||||
| node = Node( | |||||
| doc_id=str(point.id), | |||||
| text=payload.get("text"), | |||||
| embedding=point.vector, | |||||
| extra_info=payload.get("extra_info"), | |||||
| relationships={ | |||||
| DocumentRelationship.SOURCE: payload.get("doc_id", "None"), | |||||
| }, | |||||
| ) | |||||
| nodes.append(node) | |||||
| similarities.append(point.score) | |||||
| ids.append(str(point.id)) | |||||
| return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids) | |||||
| def _reload_if_needed(self): | |||||
| if isinstance(self._client._client, QdrantLocal): | |||||
| self._client._client._load() |
| from flask import Flask | |||||
| from llama_index import ServiceContext, GPTVectorStoreIndex | |||||
| from requests import ReadTimeout | |||||
| from tenacity import retry, retry_if_exception_type, stop_after_attempt | |||||
| from core.vector_store.qdrant_vector_store_client import QdrantVectorStoreClient | |||||
| from core.vector_store.weaviate_vector_store_client import WeaviateVectorStoreClient | |||||
| SUPPORTED_VECTOR_STORES = ['weaviate', 'qdrant'] | |||||
| class VectorStore: | |||||
| def __init__(self): | |||||
| self._vector_store = None | |||||
| self._client = None | |||||
| def init_app(self, app: Flask): | |||||
| if not app.config['VECTOR_STORE']: | |||||
| return | |||||
| self._vector_store = app.config['VECTOR_STORE'] | |||||
| if self._vector_store not in SUPPORTED_VECTOR_STORES: | |||||
| raise ValueError(f"Vector store {self._vector_store} is not supported.") | |||||
| if self._vector_store == 'weaviate': | |||||
| self._client = WeaviateVectorStoreClient( | |||||
| endpoint=app.config['WEAVIATE_ENDPOINT'], | |||||
| api_key=app.config['WEAVIATE_API_KEY'], | |||||
| grpc_enabled=app.config['WEAVIATE_GRPC_ENABLED'], | |||||
| batch_size=app.config['WEAVIATE_BATCH_SIZE'] | |||||
| ) | |||||
| elif self._vector_store == 'qdrant': | |||||
| self._client = QdrantVectorStoreClient( | |||||
| url=app.config['QDRANT_URL'], | |||||
| api_key=app.config['QDRANT_API_KEY'], | |||||
| root_path=app.root_path | |||||
| ) | |||||
| app.extensions['vector_store'] = self | |||||
| @retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3)) | |||||
| def get_index(self, service_context: ServiceContext, index_struct: dict) -> GPTVectorStoreIndex: | |||||
| vector_store_config: dict = index_struct.get('vector_store') | |||||
| index = self.get_client().get_index( | |||||
| service_context=service_context, | |||||
| config=vector_store_config | |||||
| ) | |||||
| return index | |||||
| def to_index_struct(self, index_id: str) -> dict: | |||||
| return { | |||||
| "type": self._vector_store, | |||||
| "vector_store": self.get_client().to_index_config(index_id) | |||||
| } | |||||
| def get_client(self): | |||||
| if not self._client: | |||||
| raise Exception("Vector store client is not initialized.") | |||||
| return self._client |
| from llama_index.indices.query.base import IS | |||||
| from typing import ( | |||||
| Any, | |||||
| Dict, | |||||
| List, | |||||
| Optional | |||||
| ) | |||||
| from llama_index.docstore import BaseDocumentStore | |||||
| from llama_index.indices.postprocessor.node import ( | |||||
| BaseNodePostprocessor, | |||||
| ) | |||||
| from llama_index.indices.vector_store import GPTVectorStoreIndexQuery | |||||
| from llama_index.indices.response.response_builder import ResponseMode | |||||
| from llama_index.indices.service_context import ServiceContext | |||||
| from llama_index.optimization.optimizer import BaseTokenUsageOptimizer | |||||
| from llama_index.prompts.prompts import ( | |||||
| QuestionAnswerPrompt, | |||||
| RefinePrompt, | |||||
| SimpleInputPrompt, | |||||
| ) | |||||
| from core.index.query.synthesizer import EnhanceResponseSynthesizer | |||||
| class EnhanceGPTVectorStoreIndexQuery(GPTVectorStoreIndexQuery): | |||||
| @classmethod | |||||
| def from_args( | |||||
| cls, | |||||
| index_struct: IS, | |||||
| service_context: ServiceContext, | |||||
| docstore: Optional[BaseDocumentStore] = None, | |||||
| node_postprocessors: Optional[List[BaseNodePostprocessor]] = None, | |||||
| verbose: bool = False, | |||||
| # response synthesizer args | |||||
| response_mode: ResponseMode = ResponseMode.DEFAULT, | |||||
| text_qa_template: Optional[QuestionAnswerPrompt] = None, | |||||
| refine_template: Optional[RefinePrompt] = None, | |||||
| simple_template: Optional[SimpleInputPrompt] = None, | |||||
| response_kwargs: Optional[Dict] = None, | |||||
| use_async: bool = False, | |||||
| streaming: bool = False, | |||||
| optimizer: Optional[BaseTokenUsageOptimizer] = None, | |||||
| # class-specific args | |||||
| **kwargs: Any, | |||||
| ) -> "BaseGPTIndexQuery": | |||||
| response_synthesizer = EnhanceResponseSynthesizer.from_args( | |||||
| service_context=service_context, | |||||
| text_qa_template=text_qa_template, | |||||
| refine_template=refine_template, | |||||
| simple_template=simple_template, | |||||
| response_mode=response_mode, | |||||
| response_kwargs=response_kwargs, | |||||
| use_async=use_async, | |||||
| streaming=streaming, | |||||
| optimizer=optimizer, | |||||
| ) | |||||
| return cls( | |||||
| index_struct=index_struct, | |||||
| service_context=service_context, | |||||
| response_synthesizer=response_synthesizer, | |||||
| docstore=docstore, | |||||
| node_postprocessors=node_postprocessors, | |||||
| verbose=verbose, | |||||
| **kwargs, | |||||
| ) |
| from langchain.vectorstores import Weaviate | |||||
| class WeaviateVectorStore(Weaviate): | |||||
| def del_texts(self, where_filter: dict): | |||||
| if not where_filter: | |||||
| raise ValueError('where_filter must not be empty') | |||||
| self._client.batch.delete_objects( | |||||
| class_name=self._index_name, | |||||
| where=where_filter, | |||||
| output='minimal' | |||||
| ) | |||||
| def del_text(self, uuid: str) -> None: | |||||
| self._client.data_object.delete( | |||||
| uuid, | |||||
| class_name=self._index_name | |||||
| ) | |||||
| def text_exists(self, uuid: str) -> bool: | |||||
| result = self._client.query.get(self._index_name).with_additional(["id"]).with_where({ | |||||
| "path": ["doc_id"], | |||||
| "operator": "Equal", | |||||
| "valueText": uuid, | |||||
| }).with_limit(1).do() | |||||
| if "errors" in result: | |||||
| raise ValueError(f"Error during query: {result['errors']}") | |||||
| entries = result["data"]["Get"][self._index_name] | |||||
| if len(entries) == 0: | |||||
| return False | |||||
| return True | |||||
| def delete(self): | |||||
| self._client.schema.delete_class(self._index_name) |
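A minimal sketch of the deletion helpers this subclass adds, assuming a reachable Weaviate instance and an existing class; the where-filter shape matches the one used by text_exists above, everything else is illustrative:

    import weaviate

    client = weaviate.Client(url="http://localhost:8080")  # assumed local instance
    store = WeaviateVectorStore(
        client=client,
        index_name="Example_class",   # assumed existing Weaviate class
        text_key="text",
    )

    # remove every object whose doc_id matches the given node id
    store.del_texts({
        "path": ["doc_id"],
        "operator": "Equal",
        "valueText": "node-uuid-to-remove",
    })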
| import json | |||||
| import weaviate | |||||
| from dataclasses import field | |||||
| from typing import List, Any, Dict, Optional | |||||
| from core.vector_store.base import BaseVectorStoreClient, BaseGPTVectorStoreIndex, EnhanceVectorStore | |||||
| from llama_index import ServiceContext, GPTWeaviateIndex, GPTVectorStoreIndex | |||||
| from llama_index.data_structs.data_structs_v2 import WeaviateIndexDict, Node | |||||
| from llama_index.data_structs.node_v2 import DocumentRelationship | |||||
| from llama_index.readers.weaviate.client import _class_name, NODE_SCHEMA, _logger | |||||
| from llama_index.vector_stores import WeaviateVectorStore | |||||
| from llama_index.vector_stores.types import VectorStoreQuery, VectorStoreQueryResult, VectorStoreQueryMode | |||||
| from llama_index.readers.weaviate.utils import ( | |||||
| parse_get_response, | |||||
| validate_client, | |||||
| ) | |||||
| class WeaviateVectorStoreClient(BaseVectorStoreClient): | |||||
| def __init__(self, endpoint: str, api_key: str, grpc_enabled: bool, batch_size: int): | |||||
| self._client = self.init_from_config(endpoint, api_key, grpc_enabled, batch_size) | |||||
| def init_from_config(self, endpoint: str, api_key: str, grpc_enabled: bool, batch_size: int): | |||||
| auth_config = weaviate.auth.AuthApiKey(api_key=api_key) | |||||
| weaviate.connect.connection.has_grpc = grpc_enabled | |||||
| client = weaviate.Client( | |||||
| url=endpoint, | |||||
| auth_client_secret=auth_config, | |||||
| timeout_config=(5, 60), | |||||
| startup_period=None | |||||
| ) | |||||
| client.batch.configure( | |||||
| # `batch_size` takes an `int` value to enable auto-batching | |||||
| # (`None` is used for manual batching) | |||||
| batch_size=batch_size, | |||||
| # dynamically update the `batch_size` based on import speed | |||||
| dynamic=True, | |||||
| # `timeout_retries` takes an `int` value to retry on time outs | |||||
| timeout_retries=3, | |||||
| ) | |||||
| return client | |||||
| def get_index(self, service_context: ServiceContext, config: dict) -> GPTVectorStoreIndex: | |||||
| index_struct = WeaviateIndexDict() | |||||
| if self._client is None: | |||||
| raise Exception("Vector client is not initialized.") | |||||
| # {"class_prefix": "Gpt_index_xxx"} | |||||
| class_prefix = config.get('class_prefix') | |||||
| if not class_prefix: | |||||
| raise Exception("class_prefix cannot be None.") | |||||
| return GPTWeaviateEnhanceIndex( | |||||
| service_context=service_context, | |||||
| index_struct=index_struct, | |||||
| vector_store=WeaviateWithSimilaritiesVectorStore( | |||||
| weaviate_client=self._client, | |||||
| class_prefix=class_prefix | |||||
| ) | |||||
| ) | |||||
| def to_index_config(self, index_id: str) -> dict: | |||||
| return {"class_prefix": index_id} | |||||
| class WeaviateWithSimilaritiesVectorStore(WeaviateVectorStore, EnhanceVectorStore): | |||||
| def query(self, query: VectorStoreQuery) -> VectorStoreQueryResult: | |||||
| """Query index for top k most similar nodes.""" | |||||
| nodes = self.weaviate_query( | |||||
| self._client, | |||||
| self._class_prefix, | |||||
| query, | |||||
| ) | |||||
| nodes = nodes[: query.similarity_top_k] | |||||
| node_idxs = [str(i) for i in range(len(nodes))] | |||||
| similarities = [] | |||||
| for node in nodes: | |||||
| similarities.append(node.extra_info['similarity']) | |||||
| del node.extra_info['similarity'] | |||||
| return VectorStoreQueryResult(nodes=nodes, ids=node_idxs, similarities=similarities) | |||||
| def weaviate_query( | |||||
| self, | |||||
| client: Any, | |||||
| class_prefix: str, | |||||
| query_spec: VectorStoreQuery, | |||||
| ) -> List[Node]: | |||||
| """Convert to LlamaIndex list.""" | |||||
| validate_client(client) | |||||
| class_name = _class_name(class_prefix) | |||||
| prop_names = [p["name"] for p in NODE_SCHEMA] | |||||
| vector = query_spec.query_embedding | |||||
| # build query | |||||
| query = client.query.get(class_name, prop_names).with_additional(["id", "vector", "certainty"]) | |||||
| if query_spec.mode == VectorStoreQueryMode.DEFAULT: | |||||
| _logger.debug("Using vector search") | |||||
| if vector is not None: | |||||
| query = query.with_near_vector( | |||||
| { | |||||
| "vector": vector, | |||||
| } | |||||
| ) | |||||
| elif query_spec.mode == VectorStoreQueryMode.HYBRID: | |||||
| _logger.debug(f"Using hybrid search with alpha {query_spec.alpha}") | |||||
| query = query.with_hybrid( | |||||
| query=query_spec.query_str, | |||||
| alpha=query_spec.alpha, | |||||
| vector=vector, | |||||
| ) | |||||
| query = query.with_limit(query_spec.similarity_top_k) | |||||
| _logger.debug(f"Using limit of {query_spec.similarity_top_k}") | |||||
| # execute query | |||||
| query_result = query.do() | |||||
| # parse results | |||||
| parsed_result = parse_get_response(query_result) | |||||
| entries = parsed_result[class_name] | |||||
| results = [self._to_node(entry) for entry in entries] | |||||
| return results | |||||
| def _to_node(self, entry: Dict) -> Node: | |||||
| """Convert to Node.""" | |||||
| extra_info_str = entry["extra_info"] | |||||
| if extra_info_str == "": | |||||
| extra_info = None | |||||
| else: | |||||
| extra_info = json.loads(extra_info_str) | |||||
| if 'certainty' in entry['_additional']: | |||||
| if extra_info: | |||||
| extra_info['similarity'] = entry['_additional']['certainty'] | |||||
| else: | |||||
| extra_info = {'similarity': entry['_additional']['certainty']} | |||||
| node_info_str = entry["node_info"] | |||||
| if node_info_str == "": | |||||
| node_info = None | |||||
| else: | |||||
| node_info = json.loads(node_info_str) | |||||
| relationships_str = entry["relationships"] | |||||
| relationships: Dict[DocumentRelationship, str] | |||||
| if relationships_str == "": | |||||
| relationships = field(default_factory=dict) | |||||
| else: | |||||
| relationships = { | |||||
| DocumentRelationship(k): v for k, v in json.loads(relationships_str).items() | |||||
| } | |||||
| return Node( | |||||
| text=entry["text"], | |||||
| doc_id=entry["doc_id"], | |||||
| embedding=entry["_additional"]["vector"], | |||||
| extra_info=extra_info, | |||||
| node_info=node_info, | |||||
| relationships=relationships, | |||||
| ) | |||||
| def delete(self, doc_id: str, **delete_kwargs: Any) -> None: | |||||
| """Delete a document. | |||||
| Args: | |||||
| doc_id (str): document id | |||||
| """ | |||||
| delete_document(self._client, doc_id, self._class_prefix) | |||||
| def delete_node(self, node_id: str): | |||||
| """ | |||||
| Delete node from the index. | |||||
| :param node_id: node id | |||||
| """ | |||||
| delete_node(self._client, node_id, self._class_prefix) | |||||
| def exists_by_node_id(self, node_id: str) -> bool: | |||||
| """ | |||||
| Get node from the index by node id. | |||||
| :param node_id: node id | |||||
| """ | |||||
| entry = get_by_node_id(self._client, node_id, self._class_prefix) | |||||
| return True if entry else False | |||||
| class GPTWeaviateEnhanceIndex(GPTWeaviateIndex, BaseGPTVectorStoreIndex): | |||||
| pass | |||||
| def delete_document(client: Any, ref_doc_id: str, class_prefix: str) -> None: | |||||
| """Delete entry.""" | |||||
| validate_client(client) | |||||
| # make sure that each entry | |||||
| class_name = _class_name(class_prefix) | |||||
| where_filter = { | |||||
| "path": ["ref_doc_id"], | |||||
| "operator": "Equal", | |||||
| "valueString": ref_doc_id, | |||||
| } | |||||
| query = ( | |||||
| client.query.get(class_name).with_additional(["id"]).with_where(where_filter) | |||||
| ) | |||||
| query_result = query.do() | |||||
| parsed_result = parse_get_response(query_result) | |||||
| entries = parsed_result[class_name] | |||||
| for entry in entries: | |||||
| client.data_object.delete(entry["_additional"]["id"], class_name) | |||||
| while len(entries) > 0: | |||||
| query_result = query.do() | |||||
| parsed_result = parse_get_response(query_result) | |||||
| entries = parsed_result[class_name] | |||||
| for entry in entries: | |||||
| client.data_object.delete(entry["_additional"]["id"], class_name) | |||||
| def delete_node(client: Any, node_id: str, class_prefix: str) -> None: | |||||
| """Delete entry.""" | |||||
| validate_client(client) | |||||
| # make sure that each entry | |||||
| class_name = _class_name(class_prefix) | |||||
| where_filter = { | |||||
| "path": ["doc_id"], | |||||
| "operator": "Equal", | |||||
| "valueString": node_id, | |||||
| } | |||||
| query = ( | |||||
| client.query.get(class_name).with_additional(["id"]).with_where(where_filter) | |||||
| ) | |||||
| query_result = query.do() | |||||
| parsed_result = parse_get_response(query_result) | |||||
| entries = parsed_result[class_name] | |||||
| for entry in entries: | |||||
| client.data_object.delete(entry["_additional"]["id"], class_name) | |||||
| def get_by_node_id(client: Any, node_id: str, class_prefix: str) -> Optional[Dict]: | |||||
| """Delete entry.""" | |||||
| validate_client(client) | |||||
| # make sure that each entry | |||||
| class_name = _class_name(class_prefix) | |||||
| where_filter = { | |||||
| "path": ["doc_id"], | |||||
| "operator": "Equal", | |||||
| "valueString": node_id, | |||||
| } | |||||
| query = ( | |||||
| client.query.get(class_name).with_additional(["id"]).with_where(where_filter) | |||||
| ) | |||||
| query_result = query.do() | |||||
| parsed_result = parse_get_response(query_result) | |||||
| entries = parsed_result[class_name] | |||||
| if len(entries) == 0: | |||||
| return None | |||||
| return entries[0] |
| from core.vector_store.vector_store import VectorStore | |||||
| vector_store = VectorStore() | |||||
| def init_app(app): | |||||
| vector_store.init_app(app) |
| import subprocess | import subprocess | ||||
| import uuid | import uuid | ||||
| from datetime import datetime | from datetime import datetime | ||||
| from hashlib import sha256 | |||||
| from zoneinfo import available_timezones | from zoneinfo import available_timezones | ||||
| import random | import random | ||||
| import string | import string | ||||
| return request.headers.getlist("X-Forwarded-For")[0] | return request.headers.getlist("X-Forwarded-For")[0] | ||||
| else: | else: | ||||
| return request.remote_addr | return request.remote_addr | ||||
| def generate_text_hash(text: str) -> str: | |||||
| hash_text = str(text) + 'None' | |||||
| return sha256(hash_text.encode()).hexdigest() |
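generate_text_hash is simply SHA-256 over the text with the literal suffix 'None' appended, so the value can be reproduced directly (importing the helper from this module):

    from hashlib import sha256

    from libs.helper import generate_text_hash

    text = "hello world"
    assert generate_text_hash(text) == sha256((text + 'None').encode()).hexdigest()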
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | ||||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | ||||
| _current_tenant: db.Model = None | |||||
| @property | @property | ||||
| def current_tenant(self): | def current_tenant(self): | ||||
| return self._current_tenant | return self._current_tenant |
| def document_count(self): | def document_count(self): | ||||
| return db.session.query(func.count(Document.id)).filter(Document.dataset_id == self.id).scalar() | return db.session.query(func.count(Document.id)).filter(Document.dataset_id == self.id).scalar() | ||||
| @property | |||||
| def available_document_count(self): | |||||
| return db.session.query(func.count(Document.id)).filter( | |||||
| Document.dataset_id == self.id, | |||||
| Document.indexing_status == 'completed', | |||||
| Document.enabled == True, | |||||
| Document.archived == False | |||||
| ).scalar() | |||||
| @property | |||||
| def available_segment_count(self): | |||||
| return db.session.query(func.count(DocumentSegment.id)).filter( | |||||
| DocumentSegment.dataset_id == self.id, | |||||
| DocumentSegment.status == 'completed', | |||||
| DocumentSegment.enabled == True | |||||
| ).scalar() | |||||
| @property | @property | ||||
| def word_count(self): | def word_count(self): | ||||
| return Document.query.with_entities(func.coalesce(func.sum(Document.word_count))) \ | return Document.query.with_entities(func.coalesce(func.sum(Document.word_count))) \ | ||||
| @property | @property | ||||
| def dataset(self): | def dataset(self): | ||||
| return Dataset.query.get(self.dataset_id) | |||||
| return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).one_or_none() | |||||
| @property | @property | ||||
| def segment_count(self): | def segment_count(self): | ||||
| @property | @property | ||||
| def keyword_table_dict(self): | def keyword_table_dict(self): | ||||
| return json.loads(self.keyword_table) if self.keyword_table else None | |||||
| class SetDecoder(json.JSONDecoder): | |||||
| def __init__(self, *args, **kwargs): | |||||
| super().__init__(object_hook=self.object_hook, *args, **kwargs) | |||||
| def object_hook(self, dct): | |||||
| if isinstance(dct, dict): | |||||
| for keyword, node_idxs in dct.items(): | |||||
| if isinstance(node_idxs, list): | |||||
| dct[keyword] = set(node_idxs) | |||||
| return dct | |||||
| return json.loads(self.keyword_table, cls=SetDecoder) if self.keyword_table else None | |||||
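A small standalone illustration of what the decoder above does: node id lists in the stored keyword table JSON come back as Python sets (the hook below restates the same idea outside the model for clarity):

    import json

    def _lists_to_sets(dct):
        # same idea as SetDecoder.object_hook: keyword -> set of node ids
        return {k: set(v) if isinstance(v, list) else v for k, v in dct.items()}

    table = json.loads('{"keyword": ["node-1", "node-2", "node-1"]}', object_hook=_lists_to_sets)
    assert table == {"keyword": {"node-1", "node-2"}}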
| class Embedding(db.Model): | class Embedding(db.Model): |
| beautifulsoup4==4.12.2 | beautifulsoup4==4.12.2 | ||||
| flask~=2.3.2 | flask~=2.3.2 | ||||
| Flask-SQLAlchemy~=3.0.3 | Flask-SQLAlchemy~=3.0.3 | ||||
| SQLAlchemy~=1.4.28 | |||||
| flask-login==0.6.2 | flask-login==0.6.2 | ||||
| flask-migrate~=4.0.4 | flask-migrate~=4.0.4 | ||||
| flask-restful==0.3.9 | flask-restful==0.3.9 | ||||
| flask-cors==3.0.10 | flask-cors==3.0.10 | ||||
| gunicorn~=20.1.0 | gunicorn~=20.1.0 | ||||
| gevent~=22.10.2 | gevent~=22.10.2 | ||||
| langchain==0.0.142 | |||||
| llama-index==0.5.27 | |||||
| langchain==0.0.209 | |||||
| openai~=0.27.5 | openai~=0.27.5 | ||||
| psycopg2-binary~=2.9.6 | psycopg2-binary~=2.9.6 | ||||
| pycryptodome==3.17 | pycryptodome==3.17 | ||||
| jieba==0.42.1 | jieba==0.42.1 | ||||
| celery==5.2.7 | celery==5.2.7 | ||||
| redis~=4.5.4 | redis~=4.5.4 | ||||
| pypdf==3.8.1 | |||||
| openpyxl==3.1.2 | openpyxl==3.1.2 | ||||
| chardet~=5.1.0 | |||||
| chardet~=5.1.0 | |||||
| docx2txt==0.8 | |||||
| pypdfium2==4.16.0 |
| from core.constant import llm_constant | from core.constant import llm_constant | ||||
| from models.account import Account | from models.account import Account | ||||
| from services.dataset_service import DatasetService | from services.dataset_service import DatasetService | ||||
| from services.errors.account import NoPermissionError | |||||
| class AppModelConfigService: | class AppModelConfigService: |
| from extensions.ext_redis import redis_client | from extensions.ext_redis import redis_client | ||||
| from flask_login import current_user | from flask_login import current_user | ||||
| from core.index.index_builder import IndexBuilder | |||||
| from events.dataset_event import dataset_was_deleted | from events.dataset_event import dataset_was_deleted | ||||
| from events.document_event import document_was_deleted | from events.document_event import document_was_deleted | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| dataset.indexing_technique = document_data["indexing_technique"] | dataset.indexing_technique = document_data["indexing_technique"] | ||||
| if dataset.indexing_technique == 'high_quality': | |||||
| IndexBuilder.get_default_service_context(dataset.tenant_id) | |||||
| documents = [] | documents = [] | ||||
| batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999)) | batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999)) | ||||
| if 'original_document_id' in document_data and document_data["original_document_id"]: | if 'original_document_id' in document_data and document_data["original_document_id"]: |
| from typing import List | from typing import List | ||||
| import numpy as np | import numpy as np | ||||
| from llama_index.data_structs.node_v2 import NodeWithScore | |||||
| from llama_index.indices.query.schema import QueryBundle | |||||
| from llama_index.indices.vector_store import GPTVectorStoreIndexQuery | |||||
| from flask import current_app | |||||
| from langchain.embeddings import OpenAIEmbeddings | |||||
| from langchain.embeddings.base import Embeddings | |||||
| from langchain.schema import Document | |||||
| from sklearn.manifold import TSNE | from sklearn.manifold import TSNE | ||||
| from core.docstore.empty_docstore import EmptyDocumentStore | |||||
| from core.index.vector_index import VectorIndex | |||||
| from core.embedding.cached_embedding import CacheEmbedding | |||||
| from core.index.vector_index.vector_index import VectorIndex | |||||
| from core.llm.llm_builder import LLMBuilder | |||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from models.account import Account | from models.account import Account | ||||
| from models.dataset import Dataset, DocumentSegment, DatasetQuery | from models.dataset import Dataset, DocumentSegment, DatasetQuery | ||||
| from services.errors.index import IndexNotInitializedError | |||||
| class HitTestingService: | class HitTestingService: | ||||
| @classmethod | @classmethod | ||||
| def retrieve(cls, dataset: Dataset, query: str, account: Account, limit: int = 10) -> dict: | def retrieve(cls, dataset: Dataset, query: str, account: Account, limit: int = 10) -> dict: | ||||
| index = VectorIndex(dataset=dataset).query_index | |||||
| if not index: | |||||
| raise IndexNotInitializedError() | |||||
| index_query = GPTVectorStoreIndexQuery( | |||||
| index_struct=index.index_struct, | |||||
| service_context=index.service_context, | |||||
| vector_store=index.query_context.get('vector_store'), | |||||
| docstore=EmptyDocumentStore(), | |||||
| response_synthesizer=None, | |||||
| similarity_top_k=limit | |||||
| ) | |||||
| if dataset.available_document_count == 0 or dataset.available_segment_count == 0: | |||||
| return { | |||||
| "query": { | |||||
| "content": query, | |||||
| "tsne_position": {'x': 0, 'y': 0}, | |||||
| }, | |||||
| "records": [] | |||||
| } | |||||
| query_bundle = QueryBundle( | |||||
| query_str=query, | |||||
| custom_embedding_strs=[query], | |||||
| model_credentials = LLMBuilder.get_model_credentials( | |||||
| tenant_id=dataset.tenant_id, | |||||
| model_provider=LLMBuilder.get_default_provider(dataset.tenant_id), | |||||
| model_name='text-embedding-ada-002' | |||||
| ) | ) | ||||
| query_bundle.embedding = index.service_context.embed_model.get_agg_embedding_from_queries( | |||||
| query_bundle.embedding_strs | |||||
| embeddings = CacheEmbedding(OpenAIEmbeddings( | |||||
| **model_credentials | |||||
| )) | |||||
| vector_index = VectorIndex( | |||||
| dataset=dataset, | |||||
| config=current_app.config, | |||||
| embeddings=embeddings | |||||
| ) | ) | ||||
| start = time.perf_counter() | start = time.perf_counter() | ||||
| nodes = index_query.retrieve(query_bundle=query_bundle) | |||||
| documents = vector_index.search( | |||||
| query, | |||||
| search_type='similarity_score_threshold', | |||||
| search_kwargs={ | |||||
| 'k': 10 | |||||
| } | |||||
| ) | |||||
| end = time.perf_counter() | end = time.perf_counter() | ||||
| logging.debug(f"Hit testing retrieve in {end - start:0.4f} seconds") | logging.debug(f"Hit testing retrieve in {end - start:0.4f} seconds") | ||||
| db.session.add(dataset_query) | db.session.add(dataset_query) | ||||
| db.session.commit() | db.session.commit() | ||||
| return cls.compact_retrieve_response(dataset, query_bundle, nodes) | |||||
| return cls.compact_retrieve_response(dataset, embeddings, query, documents) | |||||
| @classmethod | @classmethod | ||||
| def compact_retrieve_response(cls, dataset: Dataset, query_bundle: QueryBundle, nodes: List[NodeWithScore]): | |||||
| embeddings = [ | |||||
| query_bundle.embedding | |||||
| def compact_retrieve_response(cls, dataset: Dataset, embeddings: Embeddings, query: str, documents: List[Document]): | |||||
| text_embeddings = [ | |||||
| embeddings.embed_query(query) | |||||
| ] | ] | ||||
| for node in nodes: | |||||
| embeddings.append(node.node.embedding) | |||||
| text_embeddings.extend(embeddings.embed_documents([document.page_content for document in documents])) | |||||
| tsne_position_data = cls.get_tsne_positions_from_embeddings(embeddings) | |||||
| tsne_position_data = cls.get_tsne_positions_from_embeddings(text_embeddings) | |||||
| query_position = tsne_position_data.pop(0) | query_position = tsne_position_data.pop(0) | ||||
| i = 0 | i = 0 | ||||
| records = [] | records = [] | ||||
| for node in nodes: | |||||
| index_node_id = node.node.doc_id | |||||
| for document in documents: | |||||
| index_node_id = document.metadata['doc_id'] | |||||
| segment = db.session.query(DocumentSegment).filter( | segment = db.session.query(DocumentSegment).filter( | ||||
| DocumentSegment.dataset_id == dataset.id, | DocumentSegment.dataset_id == dataset.id, | ||||
| record = { | record = { | ||||
| "segment": segment, | "segment": segment, | ||||
| "score": node.score, | |||||
| "score": document.metadata['score'], | |||||
| "tsne_position": tsne_position_data[i] | "tsne_position": tsne_position_data[i] | ||||
| } | } | ||||
| return { | return { | ||||
| "query": { | "query": { | ||||
| "content": query_bundle.query_str, | |||||
| "content": query, | |||||
| "tsne_position": query_position, | "tsne_position": query_position, | ||||
| }, | }, | ||||
| "records": records | "records": records |
| import click | import click | ||||
| from celery import shared_task | from celery import shared_task | ||||
| from llama_index.data_structs import Node | |||||
| from llama_index.data_structs.node_v2 import DocumentRelationship | |||||
| from langchain.schema import Document | |||||
| from werkzeug.exceptions import NotFound | from werkzeug.exceptions import NotFound | ||||
| from core.index.keyword_table_index import KeywordTableIndex | |||||
| from core.index.vector_index import VectorIndex | |||||
| from core.index.index import IndexBuilder | |||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from extensions.ext_redis import redis_client | from extensions.ext_redis import redis_client | ||||
| from models.dataset import DocumentSegment, Document | |||||
| from models.dataset import DocumentSegment | |||||
| from models.dataset import Document as DatasetDocument | |||||
| @shared_task | @shared_task | ||||
| def add_document_to_index_task(document_id: str): | |||||
| def add_document_to_index_task(dataset_document_id: str): | |||||
| """ | """ | ||||
| Async Add document to index | Async Add document to index | ||||
| :param document_id: | |||||
| :param dataset_document_id: | |||||
| Usage: add_document_to_index.delay(document_id) | |||||
| Usage: add_document_to_index_task.delay(dataset_document_id) | |||||
| """ | """ | ||||
| logging.info(click.style('Start add document to index: {}'.format(document_id), fg='green')) | |||||
| logging.info(click.style('Start add document to index: {}'.format(dataset_document_id), fg='green')) | |||||
| start_at = time.perf_counter() | start_at = time.perf_counter() | ||||
| document = db.session.query(Document).filter(Document.id == document_id).first() | |||||
| if not document: | |||||
| dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document_id).first() | |||||
| if not dataset_document: | |||||
| raise NotFound('Document not found') | raise NotFound('Document not found') | ||||
| if document.indexing_status != 'completed': | |||||
| if dataset_document.indexing_status != 'completed': | |||||
| return | return | ||||
| indexing_cache_key = 'document_{}_indexing'.format(document.id) | |||||
| indexing_cache_key = 'document_{}_indexing'.format(dataset_document.id) | |||||
| try: | try: | ||||
| segments = db.session.query(DocumentSegment).filter( | segments = db.session.query(DocumentSegment).filter( | ||||
| DocumentSegment.document_id == document.id, | |||||
| DocumentSegment.document_id == dataset_document.id, | |||||
| DocumentSegment.enabled == True | DocumentSegment.enabled == True | ||||
| ) \ | ) \ | ||||
| .order_by(DocumentSegment.position.asc()).all() | .order_by(DocumentSegment.position.asc()).all() | ||||
| nodes = [] | |||||
| previous_node = None | |||||
| documents = [] | |||||
| for segment in segments: | for segment in segments: | ||||
| relationships = { | |||||
| DocumentRelationship.SOURCE: document.id | |||||
| } | |||||
| if previous_node: | |||||
| relationships[DocumentRelationship.PREVIOUS] = previous_node.doc_id | |||||
| previous_node.relationships[DocumentRelationship.NEXT] = segment.index_node_id | |||||
| node = Node( | |||||
| doc_id=segment.index_node_id, | |||||
| doc_hash=segment.index_node_hash, | |||||
| text=segment.content, | |||||
| extra_info=None, | |||||
| node_info=None, | |||||
| relationships=relationships | |||||
| document = Document( | |||||
| page_content=segment.content, | |||||
| metadata={ | |||||
| "doc_id": segment.index_node_id, | |||||
| "doc_hash": segment.index_node_hash, | |||||
| "document_id": segment.document_id, | |||||
| "dataset_id": segment.dataset_id, | |||||
| } | |||||
| ) | ) | ||||
| previous_node = node | |||||
| documents.append(document) | |||||
| nodes.append(node) | |||||
| dataset = document.dataset | |||||
| dataset = dataset_document.dataset | |||||
| if not dataset: | if not dataset: | ||||
| raise Exception('Document has no dataset') | raise Exception('Document has no dataset') | ||||
| vector_index = VectorIndex(dataset=dataset) | |||||
| keyword_table_index = KeywordTableIndex(dataset=dataset) | |||||
| # save vector index | # save vector index | ||||
| if dataset.indexing_technique == "high_quality": | |||||
| vector_index.add_nodes( | |||||
| nodes=nodes, | |||||
| duplicate_check=True | |||||
| ) | |||||
| index = IndexBuilder.get_index(dataset, 'high_quality') | |||||
| if index: | |||||
| index.add_texts(documents) | |||||
| # save keyword index | # save keyword index | ||||
| keyword_table_index.add_nodes(nodes) | |||||
| index = IndexBuilder.get_index(dataset, 'economy') | |||||
| if index: | |||||
| index.add_texts(documents) | |||||
| end_at = time.perf_counter() | end_at = time.perf_counter() | ||||
| logging.info( | logging.info( | ||||
| click.style('Document added to index: {} latency: {}'.format(document.id, end_at - start_at), fg='green')) | |||||
| click.style('Document added to index: {} latency: {}'.format(dataset_document.id, end_at - start_at), fg='green')) | |||||
| except Exception as e: | except Exception as e: | ||||
| logging.exception("add document to index failed") | logging.exception("add document to index failed") | ||||
| document.enabled = False | |||||
| document.disabled_at = datetime.datetime.utcnow() | |||||
| document.status = 'error' | |||||
| document.error = str(e) | |||||
| dataset_document.enabled = False | |||||
| dataset_document.disabled_at = datetime.datetime.utcnow() | |||||
| dataset_document.status = 'error' | |||||
| dataset_document.error = str(e) | |||||
| db.session.commit() | db.session.commit() | ||||
| finally: | finally: | ||||
| redis_client.delete(indexing_cache_key) | redis_client.delete(indexing_cache_key) |
| import click | import click | ||||
| from celery import shared_task | from celery import shared_task | ||||
| from llama_index.data_structs import Node | |||||
| from llama_index.data_structs.node_v2 import DocumentRelationship | |||||
| from langchain.schema import Document | |||||
| from werkzeug.exceptions import NotFound | from werkzeug.exceptions import NotFound | ||||
| from core.index.keyword_table_index import KeywordTableIndex | |||||
| from core.index.vector_index import VectorIndex | |||||
| from core.index.index import IndexBuilder | |||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from extensions.ext_redis import redis_client | from extensions.ext_redis import redis_client | ||||
| from models.dataset import DocumentSegment | from models.dataset import DocumentSegment | ||||
| indexing_cache_key = 'segment_{}_indexing'.format(segment.id) | indexing_cache_key = 'segment_{}_indexing'.format(segment.id) | ||||
| try: | try: | ||||
| relationships = { | |||||
| DocumentRelationship.SOURCE: segment.document_id, | |||||
| } | |||||
| previous_segment = segment.previous_segment | |||||
| if previous_segment: | |||||
| relationships[DocumentRelationship.PREVIOUS] = previous_segment.index_node_id | |||||
| next_segment = segment.next_segment | |||||
| if next_segment: | |||||
| relationships[DocumentRelationship.NEXT] = next_segment.index_node_id | |||||
| node = Node( | |||||
| doc_id=segment.index_node_id, | |||||
| doc_hash=segment.index_node_hash, | |||||
| text=segment.content, | |||||
| extra_info=None, | |||||
| node_info=None, | |||||
| relationships=relationships | |||||
| document = Document( | |||||
| page_content=segment.content, | |||||
| metadata={ | |||||
| "doc_id": segment.index_node_id, | |||||
| "doc_hash": segment.index_node_hash, | |||||
| "document_id": segment.document_id, | |||||
| "dataset_id": segment.dataset_id, | |||||
| } | |||||
| ) | ) | ||||
| dataset = segment.dataset | dataset = segment.dataset | ||||
| if not dataset: | if not dataset: | ||||
| raise Exception('Segment has no dataset') | |||||
| logging.info(click.style('Segment {} has no dataset, pass.'.format(segment.id), fg='cyan')) | |||||
| return | |||||
| vector_index = VectorIndex(dataset=dataset) | |||||
| keyword_table_index = KeywordTableIndex(dataset=dataset) | |||||
| dataset_document = segment.document | |||||
| if not dataset_document: | |||||
| logging.info(click.style('Segment {} has no document, pass.'.format(segment.id), fg='cyan')) | |||||
| return | |||||
| if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed': | |||||
| logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment.id), fg='cyan')) | |||||
| return | |||||
| # save vector index | # save vector index | ||||
| if dataset.indexing_technique == "high_quality": | |||||
| vector_index.add_nodes( | |||||
| nodes=[node], | |||||
| duplicate_check=True | |||||
| ) | |||||
| index = IndexBuilder.get_index(dataset, 'high_quality') | |||||
| if index: | |||||
| index.add_texts([document], duplicate_check=True) | |||||
| # save keyword index | # save keyword index | ||||
| keyword_table_index.add_nodes([node]) | |||||
| index = IndexBuilder.get_index(dataset, 'economy') | |||||
| if index: | |||||
| index.add_texts([document]) | |||||
| end_at = time.perf_counter() | end_at = time.perf_counter() | ||||
| logging.info(click.style('Segment added to index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green')) | logging.info(click.style('Segment added to index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green')) |
| import click | import click | ||||
| from celery import shared_task | from celery import shared_task | ||||
| from core.index.keyword_table_index import KeywordTableIndex | |||||
| from core.index.vector_index import VectorIndex | |||||
| from core.index.index import IndexBuilder | |||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from models.dataset import DocumentSegment, Dataset, DatasetKeywordTable, DatasetQuery, DatasetProcessRule, \ | from models.dataset import DocumentSegment, Dataset, DatasetKeywordTable, DatasetQuery, DatasetProcessRule, \ | ||||
| AppDatasetJoin | AppDatasetJoin | ||||
| index_struct=index_struct | index_struct=index_struct | ||||
| ) | ) | ||||
| vector_index = VectorIndex(dataset=dataset) | |||||
| keyword_table_index = KeywordTableIndex(dataset=dataset) | |||||
| documents = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all() | documents = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all() | ||||
| index_doc_ids = [document.id for document in documents] | |||||
| segments = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all() | segments = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all() | ||||
| index_node_ids = [segment.index_node_id for segment in segments] | |||||
| # delete from vector index | |||||
| if dataset.indexing_technique == "high_quality": | |||||
| for index_doc_id in index_doc_ids: | |||||
| try: | |||||
| vector_index.del_doc(index_doc_id) | |||||
| except Exception: | |||||
| logging.exception("Delete doc index failed when dataset deleted.") | |||||
| continue | |||||
| vector_index = IndexBuilder.get_index(dataset, 'high_quality') | |||||
| kw_index = IndexBuilder.get_index(dataset, 'economy') | |||||
| # delete from keyword index | |||||
| if index_node_ids: | |||||
| # delete from vector index | |||||
| if vector_index: | |||||
| try: | try: | ||||
| keyword_table_index.del_nodes(index_node_ids) | |||||
| vector_index.delete() | |||||
| except Exception: | except Exception: | ||||
| logging.exception("Delete nodes index failed when dataset deleted.") | |||||
| logging.exception("Delete doc index failed when dataset deleted.") | |||||
| # delete from keyword index | |||||
| try: | |||||
| kw_index.delete() | |||||
| except Exception: | |||||
| logging.exception("Delete nodes index failed when dataset deleted.") | |||||
| for document in documents: | for document in documents: | ||||
| db.session.delete(document) | db.session.delete(document) | ||||
| for segment in segments: | for segment in segments: | ||||
| db.session.delete(segment) | db.session.delete(segment) | ||||
| db.session.query(DatasetKeywordTable).filter(DatasetKeywordTable.dataset_id == dataset_id).delete() | |||||
| db.session.query(DatasetProcessRule).filter(DatasetProcessRule.dataset_id == dataset_id).delete() | db.session.query(DatasetProcessRule).filter(DatasetProcessRule.dataset_id == dataset_id).delete() | ||||
| db.session.query(DatasetQuery).filter(DatasetQuery.dataset_id == dataset_id).delete() | db.session.query(DatasetQuery).filter(DatasetQuery.dataset_id == dataset_id).delete() | ||||
| db.session.query(AppDatasetJoin).filter(AppDatasetJoin.dataset_id == dataset_id).delete() | db.session.query(AppDatasetJoin).filter(AppDatasetJoin.dataset_id == dataset_id).delete() |
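For dataset deletion, the per-document removal loop above gives way to deleting each index wholesale. A condensed sketch of that cleanup, under the same IndexBuilder assumption (delete_dataset_indexes is an illustrative wrapper name):

import logging

from core.index.index import IndexBuilder


def delete_dataset_indexes(dataset):
    vector_index = IndexBuilder.get_index(dataset, 'high_quality')
    kw_index = IndexBuilder.get_index(dataset, 'economy')

    # Drop the dataset's vector index in one call, if one was ever built.
    if vector_index:
        try:
            vector_index.delete()
        except Exception:
            logging.exception("Delete doc index failed when dataset deleted.")

    # Drop the keyword table index as well; failures are logged rather than
    # raised so the surrounding database cleanup can still run.
    try:
        kw_index.delete()
    except Exception:
        logging.exception("Delete nodes index failed when dataset deleted.")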
| import click | import click | ||||
| from celery import shared_task | from celery import shared_task | ||||
| from core.index.keyword_table_index import KeywordTableIndex | |||||
| from core.index.vector_index import VectorIndex | |||||
| from core.index.index import IndexBuilder | |||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from models.dataset import DocumentSegment, Dataset | from models.dataset import DocumentSegment, Dataset | ||||
| if not dataset: | if not dataset: | ||||
| raise Exception('Document has no dataset') | raise Exception('Document has no dataset') | ||||
| vector_index = VectorIndex(dataset=dataset) | |||||
| keyword_table_index = KeywordTableIndex(dataset=dataset) | |||||
| vector_index = IndexBuilder.get_index(dataset, 'high_quality') | |||||
| kw_index = IndexBuilder.get_index(dataset, 'economy') | |||||
| segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() | segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() | ||||
| index_node_ids = [segment.index_node_id for segment in segments] | index_node_ids = [segment.index_node_id for segment in segments] | ||||
| # delete from vector index | # delete from vector index | ||||
| vector_index.del_nodes(index_node_ids) | |||||
| if vector_index: | |||||
| vector_index.delete_by_document_id(document_id) | |||||
| # delete from keyword index | # delete from keyword index | ||||
| if index_node_ids: | if index_node_ids: | ||||
| keyword_table_index.del_nodes(index_node_ids) | |||||
| kw_index.delete_by_ids(index_node_ids) | |||||
| for segment in segments: | for segment in segments: | ||||
| db.session.delete(segment) | db.session.delete(segment) | ||||
| db.session.commit() | db.session.commit() | ||||
| end_at = time.perf_counter() | end_at = time.perf_counter() | ||||
| logging.info( | logging.info( |
| import click | import click | ||||
| from celery import shared_task | from celery import shared_task | ||||
| from core.index.keyword_table_index import KeywordTableIndex | |||||
| from core.index.vector_index import VectorIndex | |||||
| from core.index.index import IndexBuilder | |||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from models.dataset import DocumentSegment, Dataset, Document | from models.dataset import DocumentSegment, Dataset, Document | ||||
| if not dataset: | if not dataset: | ||||
| raise Exception('Document has no dataset') | raise Exception('Document has no dataset') | ||||
| vector_index = VectorIndex(dataset=dataset) | |||||
| keyword_table_index = KeywordTableIndex(dataset=dataset) | |||||
| vector_index = IndexBuilder.get_index(dataset, 'high_quality') | |||||
| kw_index = IndexBuilder.get_index(dataset, 'economy') | |||||
| for document_id in document_ids: | for document_id in document_ids: | ||||
| document = db.session.query(Document).filter( | document = db.session.query(Document).filter( | ||||
| Document.id == document_id | Document.id == document_id | ||||
| ).first() | ).first() | ||||
| db.session.delete(document) | db.session.delete(document) | ||||
| segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() | segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() | ||||
| index_node_ids = [segment.index_node_id for segment in segments] | index_node_ids = [segment.index_node_id for segment in segments] | ||||
| # delete from vector index | # delete from vector index | ||||
| vector_index.del_nodes(index_node_ids) | |||||
| if vector_index: | |||||
| vector_index.delete_by_document_id(document_id) | |||||
| # delete from keyword index | # delete from keyword index | ||||
| if index_node_ids: | if index_node_ids: | ||||
| keyword_table_index.del_nodes(index_node_ids) | |||||
| kw_index.delete_by_ids(index_node_ids) | |||||
| for segment in segments: | for segment in segments: | ||||
| db.session.delete(segment) | db.session.delete(segment) |
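Both document-cleanup tasks now follow the same two-step removal: delete the document from the vector store by its document id, then delete the segments' node ids from the keyword index. A sketch of that pattern, again assuming the IndexBuilder API shown in the diff (remove_document_indexes is an illustrative name):

from core.index.index import IndexBuilder
from extensions.ext_database import db
from models.dataset import DocumentSegment


def remove_document_indexes(dataset, document_id):
    vector_index = IndexBuilder.get_index(dataset, 'high_quality')
    kw_index = IndexBuilder.get_index(dataset, 'economy')

    segments = db.session.query(DocumentSegment).filter(
        DocumentSegment.document_id == document_id
    ).all()
    index_node_ids = [segment.index_node_id for segment in segments]

    # Vector side: one call per document instead of one per node id.
    if vector_index:
        vector_index.delete_by_document_id(document_id)

    # Keyword side: still keyed by the segments' index node ids.
    if index_node_ids:
        kw_index.delete_by_ids(index_node_ids)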
| import click | import click | ||||
| from celery import shared_task | from celery import shared_task | ||||
| from llama_index.data_structs.node_v2 import DocumentRelationship, Node | |||||
| from core.index.vector_index import VectorIndex | |||||
| from langchain.schema import Document | |||||
| from core.index.index import IndexBuilder | |||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from models.dataset import DocumentSegment, Document, Dataset | |||||
| from models.dataset import DocumentSegment, Dataset | |||||
| from models.dataset import Document as DatasetDocument | |||||
| @shared_task | @shared_task | ||||
| dataset = Dataset.query.filter_by( | dataset = Dataset.query.filter_by( | ||||
| id=dataset_id | id=dataset_id | ||||
| ).first() | ).first() | ||||
| if not dataset: | if not dataset: | ||||
| raise Exception('Dataset not found') | raise Exception('Dataset not found') | ||||
| documents = Document.query.filter_by(dataset_id=dataset_id).all() | |||||
| if documents: | |||||
| vector_index = VectorIndex(dataset=dataset) | |||||
| for document in documents: | |||||
| # delete from vector index | |||||
| if action == "remove": | |||||
| vector_index.del_doc(document.id) | |||||
| elif action == "add": | |||||
| if action == "remove": | |||||
| index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=True) | |||||
| index.delete() | |||||
| elif action == "add": | |||||
| dataset_documents = db.session.query(DatasetDocument).filter( | |||||
| DatasetDocument.dataset_id == dataset_id, | |||||
| DatasetDocument.indexing_status == 'completed', | |||||
| DatasetDocument.enabled == True, | |||||
| DatasetDocument.archived == False, | |||||
| ).all() | |||||
| if dataset_documents: | |||||
| # save vector index | |||||
| index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=True) | |||||
| for dataset_document in dataset_documents: | |||||
| # delete from vector index | |||||
| segments = db.session.query(DocumentSegment).filter( | segments = db.session.query(DocumentSegment).filter( | ||||
| DocumentSegment.document_id == document.id, | |||||
| DocumentSegment.document_id == dataset_document.id, | |||||
| DocumentSegment.enabled == True | DocumentSegment.enabled == True | ||||
| ).order_by(DocumentSegment.position.asc()).all() | ).order_by(DocumentSegment.position.asc()).all() | ||||
| nodes = [] | |||||
| previous_node = None | |||||
| documents = [] | |||||
| for segment in segments: | for segment in segments: | ||||
| relationships = { | |||||
| DocumentRelationship.SOURCE: document.id | |||||
| } | |||||
| if previous_node: | |||||
| relationships[DocumentRelationship.PREVIOUS] = previous_node.doc_id | |||||
| previous_node.relationships[DocumentRelationship.NEXT] = segment.index_node_id | |||||
| node = Node( | |||||
| doc_id=segment.index_node_id, | |||||
| doc_hash=segment.index_node_hash, | |||||
| text=segment.content, | |||||
| extra_info=None, | |||||
| node_info=None, | |||||
| relationships=relationships | |||||
| document = Document( | |||||
| page_content=segment.content, | |||||
| metadata={ | |||||
| "doc_id": segment.index_node_id, | |||||
| "doc_hash": segment.index_node_hash, | |||||
| "document_id": segment.document_id, | |||||
| "dataset_id": segment.dataset_id, | |||||
| } | |||||
| ) | ) | ||||
| previous_node = node | |||||
| nodes.append(node) | |||||
| documents.append(document) | |||||
| # save vector index | # save vector index | ||||
| vector_index.add_nodes( | |||||
| nodes=nodes, | |||||
| duplicate_check=True | |||||
| ) | |||||
| index.add_texts(documents) | |||||
| end_at = time.perf_counter() | end_at = time.perf_counter() | ||||
| logging.info( | logging.info( |
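The 'remove'/'add' actions above drop or rebuild a dataset's vector index as a whole; the lookups pass ignore_high_quality_check=True, presumably because this task runs while the dataset's indexing technique is being switched. Because the hunk interleaves the old Node-based loop with the new one, here is the new control flow pulled out on its own, with the Document conversion repeated from the earlier sketch (deal_dataset_vector_index is an illustrative name and takes a dataset object rather than an id):

from langchain.schema import Document

from core.index.index import IndexBuilder
from extensions.ext_database import db
from models.dataset import DocumentSegment
from models.dataset import Document as DatasetDocument


def deal_dataset_vector_index(dataset, action):
    if action == "remove":
        index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=True)
        index.delete()
    elif action == "add":
        # Only completed, enabled, non-archived documents are re-indexed.
        dataset_documents = db.session.query(DatasetDocument).filter(
            DatasetDocument.dataset_id == dataset.id,
            DatasetDocument.indexing_status == 'completed',
            DatasetDocument.enabled == True,
            DatasetDocument.archived == False,
        ).all()
        if not dataset_documents:
            return

        index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=True)
        for dataset_document in dataset_documents:
            segments = db.session.query(DocumentSegment).filter(
                DocumentSegment.document_id == dataset_document.id,
                DocumentSegment.enabled == True
            ).order_by(DocumentSegment.position.asc()).all()

            documents = [
                Document(
                    page_content=segment.content,
                    metadata={
                        "doc_id": segment.index_node_id,
                        "doc_hash": segment.index_node_hash,
                        "document_id": segment.document_id,
                        "dataset_id": segment.dataset_id,
                    }
                )
                for segment in segments
            ]
            # One add_texts call per document, covering all of its enabled segments.
            index.add_texts(documents)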
| from celery import shared_task | from celery import shared_task | ||||
| from werkzeug.exceptions import NotFound | from werkzeug.exceptions import NotFound | ||||
| from core.data_source.notion import NotionPageReader | |||||
| from core.index.keyword_table_index import KeywordTableIndex | |||||
| from core.index.vector_index import VectorIndex | |||||
| from core.data_loader.loader.notion import NotionLoader | |||||
| from core.index.index import IndexBuilder | |||||
| from core.indexing_runner import IndexingRunner, DocumentIsPausedException | from core.indexing_runner import IndexingRunner, DocumentIsPausedException | ||||
| from core.llm.error import ProviderTokenNotInitError | |||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from models.dataset import Document, Dataset, DocumentSegment | from models.dataset import Document, Dataset, DocumentSegment | ||||
| from models.source import DataSourceBinding | from models.source import DataSourceBinding | ||||
| raise ValueError("no notion page found") | raise ValueError("no notion page found") | ||||
| workspace_id = data_source_info['notion_workspace_id'] | workspace_id = data_source_info['notion_workspace_id'] | ||||
| page_id = data_source_info['notion_page_id'] | page_id = data_source_info['notion_page_id'] | ||||
| page_type = data_source_info['type'] | |||||
| page_edited_time = data_source_info['last_edited_time'] | page_edited_time = data_source_info['last_edited_time'] | ||||
| data_source_binding = DataSourceBinding.query.filter( | data_source_binding = DataSourceBinding.query.filter( | ||||
| db.and_( | db.and_( | ||||
| ).first() | ).first() | ||||
| if not data_source_binding: | if not data_source_binding: | ||||
| raise ValueError('Data source binding not found.') | raise ValueError('Data source binding not found.') | ||||
| reader = NotionPageReader(integration_token=data_source_binding.access_token) | |||||
| last_edited_time = reader.get_page_last_edited_time(page_id) | |||||
| loader = NotionLoader( | |||||
| notion_access_token=data_source_binding.access_token, | |||||
| notion_workspace_id=workspace_id, | |||||
| notion_obj_id=page_id, | |||||
| notion_page_type=page_type | |||||
| ) | |||||
| last_edited_time = loader.get_notion_last_edited_time() | |||||
| # check the page is updated | # check the page is updated | ||||
| if last_edited_time != page_edited_time: | if last_edited_time != page_edited_time: | ||||
| document.indexing_status = 'parsing' | document.indexing_status = 'parsing' | ||||
| if not dataset: | if not dataset: | ||||
| raise Exception('Dataset not found') | raise Exception('Dataset not found') | ||||
| vector_index = VectorIndex(dataset=dataset) | |||||
| keyword_table_index = KeywordTableIndex(dataset=dataset) | |||||
| vector_index = IndexBuilder.get_index(dataset, 'high_quality') | |||||
| kw_index = IndexBuilder.get_index(dataset, 'economy') | |||||
| segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() | segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() | ||||
| index_node_ids = [segment.index_node_id for segment in segments] | index_node_ids = [segment.index_node_id for segment in segments] | ||||
| # delete from vector index | # delete from vector index | ||||
| vector_index.del_nodes(index_node_ids) | |||||
| if vector_index: | |||||
| vector_index.delete_by_document_id(document_id) | |||||
| # delete from keyword index | # delete from keyword index | ||||
| if index_node_ids: | if index_node_ids: | ||||
| keyword_table_index.del_nodes(index_node_ids) | |||||
| kw_index.delete_by_ids(index_node_ids) | |||||
| for segment in segments: | for segment in segments: | ||||
| db.session.delete(segment) | db.session.delete(segment) | ||||
| click.style('Cleaned document when document update data source or process rule: {} latency: {}'.format(document_id, end_at - start_at), fg='green')) | click.style('Cleaned document when document update data source or process rule: {} latency: {}'.format(document_id, end_at - start_at), fg='green')) | ||||
| except Exception: | except Exception: | ||||
| logging.exception("Cleaned document when document update data source or process rule failed") | logging.exception("Cleaned document when document update data source or process rule failed") | ||||
| try: | try: | ||||
| indexing_runner = IndexingRunner() | indexing_runner = IndexingRunner() | ||||
| indexing_runner.run([document]) | indexing_runner.run([document]) | ||||
| end_at = time.perf_counter() | end_at = time.perf_counter() | ||||
| logging.info(click.style('update document: {} latency: {}'.format(document.id, end_at - start_at), fg='green')) | logging.info(click.style('update document: {} latency: {}'.format(document.id, end_at - start_at), fg='green')) | ||||
| except DocumentIsPausedException: | |||||
| logging.info(click.style('Document update paused, document id: {}'.format(document.id), fg='yellow')) | |||||
| except ProviderTokenNotInitError as e: | |||||
| document.indexing_status = 'error' | |||||
| document.error = str(e.description) | |||||
| document.stopped_at = datetime.datetime.utcnow() | |||||
| db.session.commit() | |||||
| except Exception as e: | |||||
| logging.exception("consume update document failed") | |||||
| document.indexing_status = 'error' | |||||
| document.error = str(e) | |||||
| document.stopped_at = datetime.datetime.utcnow() | |||||
| db.session.commit() | |||||
| except DocumentIsPausedException as ex: | |||||
| logging.info(click.style(str(ex), fg='yellow')) | |||||
| except Exception: | |||||
| pass |
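The Notion sync task now reads the page's last edit time through NotionLoader instead of NotionPageReader, and its error handling collapses to a single DocumentIsPausedException branch plus a silent catch-all. A sketch of just the freshness check, using the constructor arguments shown above (notion_page_changed is an illustrative name):

from core.data_loader.loader.notion import NotionLoader


def notion_page_changed(data_source_binding, workspace_id, page_id, page_type, page_edited_time):
    loader = NotionLoader(
        notion_access_token=data_source_binding.access_token,
        notion_workspace_id=workspace_id,
        notion_obj_id=page_id,
        notion_page_type=page_type
    )
    # Compare Notion's current last-edited timestamp with the one stored in
    # the document's data_source_info; re-indexing only runs when they differ.
    last_edited_time = loader.get_notion_last_edited_time()
    return last_edited_time != page_edited_time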
| from werkzeug.exceptions import NotFound | from werkzeug.exceptions import NotFound | ||||
| from core.indexing_runner import IndexingRunner, DocumentIsPausedException | from core.indexing_runner import IndexingRunner, DocumentIsPausedException | ||||
| from core.llm.error import ProviderTokenNotInitError | |||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from models.dataset import Document | from models.dataset import Document | ||||
| Usage: document_indexing_task.delay(dataset_id, document_id) | Usage: document_indexing_task.delay(dataset_id, document_id) | ||||
| """ | """ | ||||
| documents = [] | documents = [] | ||||
| start_at = time.perf_counter() | |||||
| for document_id in document_ids: | for document_id in document_ids: | ||||
| logging.info(click.style('Start process document: {}'.format(document_id), fg='green')) | logging.info(click.style('Start process document: {}'.format(document_id), fg='green')) | ||||
| start_at = time.perf_counter() | |||||
| document = db.session.query(Document).filter( | document = db.session.query(Document).filter( | ||||
| Document.id == document_id, | Document.id == document_id, | ||||
| indexing_runner = IndexingRunner() | indexing_runner = IndexingRunner() | ||||
| indexing_runner.run(documents) | indexing_runner.run(documents) | ||||
| end_at = time.perf_counter() | end_at = time.perf_counter() | ||||
| logging.info(click.style('Processed document: {} latency: {}'.format(document.id, end_at - start_at), fg='green')) | |||||
| except DocumentIsPausedException: | |||||
| logging.info(click.style('Document paused, document id: {}'.format(document.id), fg='yellow')) | |||||
| except ProviderTokenNotInitError as e: | |||||
| document.indexing_status = 'error' | |||||
| document.error = str(e.description) | |||||
| document.stopped_at = datetime.datetime.utcnow() | |||||
| db.session.commit() | |||||
| except Exception as e: | |||||
| logging.exception("consume document failed") | |||||
| document.indexing_status = 'error' | |||||
| document.error = str(e) | |||||
| document.stopped_at = datetime.datetime.utcnow() | |||||
| db.session.commit() | |||||
| logging.info(click.style('Processed dataset: {} latency: {}'.format(dataset_id, end_at - start_at), fg='green')) | |||||
| except DocumentIsPausedException as ex: | |||||
| logging.info(click.style(str(ex), fg='yellow')) | |||||
| except Exception: | |||||
| pass |
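In the batch indexing task the timer now starts once before the loop and the summary is logged per dataset rather than per document; DocumentIsPausedException is logged, and any other failure is swallowed instead of being written back to the document row as before. A sketch of the resulting control flow (run_documents is an illustrative name; the real task also loads and filters the Document rows first):

import logging
import time

import click

from core.indexing_runner import IndexingRunner, DocumentIsPausedException


def run_documents(dataset_id, documents):
    start_at = time.perf_counter()
    try:
        indexing_runner = IndexingRunner()
        indexing_runner.run(documents)
        end_at = time.perf_counter()
        logging.info(click.style(
            'Processed dataset: {} latency: {}'.format(dataset_id, end_at - start_at), fg='green'))
    except DocumentIsPausedException as ex:
        # Pausing is an expected condition and is only logged.
        logging.info(click.style(str(ex), fg='yellow'))
    except Exception:
        # Other failures are intentionally ignored here, matching the hunk above.
        pass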
| from celery import shared_task | from celery import shared_task | ||||
| from werkzeug.exceptions import NotFound | from werkzeug.exceptions import NotFound | ||||
| from core.index.keyword_table_index import KeywordTableIndex | |||||
| from core.index.vector_index import VectorIndex | |||||
| from core.index.index import IndexBuilder | |||||
| from core.indexing_runner import IndexingRunner, DocumentIsPausedException | from core.indexing_runner import IndexingRunner, DocumentIsPausedException | ||||
| from core.llm.error import ProviderTokenNotInitError | |||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from models.dataset import Document, Dataset, DocumentSegment | from models.dataset import Document, Dataset, DocumentSegment | ||||
| if not dataset: | if not dataset: | ||||
| raise Exception('Dataset not found') | raise Exception('Dataset not found') | ||||
| vector_index = VectorIndex(dataset=dataset) | |||||
| keyword_table_index = KeywordTableIndex(dataset=dataset) | |||||
| vector_index = IndexBuilder.get_index(dataset, 'high_quality') | |||||
| kw_index = IndexBuilder.get_index(dataset, 'economy') | |||||
| segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() | segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() | ||||
| index_node_ids = [segment.index_node_id for segment in segments] | index_node_ids = [segment.index_node_id for segment in segments] | ||||
| # delete from vector index | # delete from vector index | ||||
| vector_index.del_nodes(index_node_ids) | |||||
| if vector_index: | |||||
| vector_index.delete_by_ids(index_node_ids) | |||||
| # delete from keyword index | # delete from keyword index | ||||
| if index_node_ids: | if index_node_ids: | ||||
| keyword_table_index.del_nodes(index_node_ids) | |||||
| kw_index.delete_by_ids(index_node_ids) | |||||
| for segment in segments: | for segment in segments: | ||||
| db.session.delete(segment) | db.session.delete(segment) | ||||
| click.style('Cleaned document when document update data source or process rule: {} latency: {}'.format(document_id, end_at - start_at), fg='green')) | click.style('Cleaned document when document update data source or process rule: {} latency: {}'.format(document_id, end_at - start_at), fg='green')) | ||||
| except Exception: | except Exception: | ||||
| logging.exception("Cleaned document when document update data source or process rule failed") | logging.exception("Cleaned document when document update data source or process rule failed") | ||||
| try: | try: | ||||
| indexing_runner = IndexingRunner() | indexing_runner = IndexingRunner() | ||||
| indexing_runner.run([document]) | indexing_runner.run([document]) | ||||
| end_at = time.perf_counter() | end_at = time.perf_counter() | ||||
| logging.info(click.style('update document: {} latency: {}'.format(document.id, end_at - start_at), fg='green')) | logging.info(click.style('update document: {} latency: {}'.format(document.id, end_at - start_at), fg='green')) | ||||
| except DocumentIsPausedException: | |||||
| logging.info(click.style('Document update paused, document id: {}'.format(document.id), fg='yellow')) | |||||
| except ProviderTokenNotInitError as e: | |||||
| document.indexing_status = 'error' | |||||
| document.error = str(e.description) | |||||
| document.stopped_at = datetime.datetime.utcnow() | |||||
| db.session.commit() | |||||
| except Exception as e: | |||||
| logging.exception("consume update document failed") | |||||
| document.indexing_status = 'error' | |||||
| document.error = str(e) | |||||
| document.stopped_at = datetime.datetime.utcnow() | |||||
| db.session.commit() | |||||
| except DocumentIsPausedException as ex: | |||||
| logging.info(click.style(str(ex), fg='yellow')) | |||||
| except Exception: | |||||
| pass |
| import datetime | |||||
| import logging | import logging | ||||
| import time | import time | ||||
| indexing_runner.run_in_indexing_status(document) | indexing_runner.run_in_indexing_status(document) | ||||
| end_at = time.perf_counter() | end_at = time.perf_counter() | ||||
| logging.info(click.style('Processed document: {} latency: {}'.format(document.id, end_at - start_at), fg='green')) | logging.info(click.style('Processed document: {} latency: {}'.format(document.id, end_at - start_at), fg='green')) | ||||
| except DocumentIsPausedException: | |||||
| logging.info(click.style('Document paused, document id: {}'.format(document.id), fg='yellow')) | |||||
| except Exception as e: | |||||
| logging.exception("consume document failed") | |||||
| document.indexing_status = 'error' | |||||
| document.error = str(e) | |||||
| document.stopped_at = datetime.datetime.utcnow() | |||||
| db.session.commit() | |||||
| except DocumentIsPausedException as ex: | |||||
| logging.info(click.style(str(ex), fg='yellow')) | |||||
| except Exception: | |||||
| pass |
| from celery import shared_task | from celery import shared_task | ||||
| from werkzeug.exceptions import NotFound | from werkzeug.exceptions import NotFound | ||||
| from core.index.keyword_table_index import KeywordTableIndex | |||||
| from core.index.vector_index import VectorIndex | |||||
| from core.index.index import IndexBuilder | |||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from extensions.ext_redis import redis_client | from extensions.ext_redis import redis_client | ||||
| from models.dataset import DocumentSegment, Document | from models.dataset import DocumentSegment, Document | ||||
| if not dataset: | if not dataset: | ||||
| raise Exception('Document has no dataset') | raise Exception('Document has no dataset') | ||||
| vector_index = VectorIndex(dataset=dataset) | |||||
| keyword_table_index = KeywordTableIndex(dataset=dataset) | |||||
| vector_index = IndexBuilder.get_index(dataset, 'high_quality') | |||||
| kw_index = IndexBuilder.get_index(dataset, 'economy') | |||||
| # delete from vector index | # delete from vector index | ||||
| vector_index.del_doc(document.id) | |||||
| vector_index.delete_by_document_id(document.id) | |||||
| # delete from keyword index | # delete from keyword index | ||||
| segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).all() | segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).all() | ||||
| index_node_ids = [segment.index_node_id for segment in segments] | index_node_ids = [segment.index_node_id for segment in segments] | ||||
| if index_node_ids: | if index_node_ids: | ||||
| keyword_table_index.del_nodes(index_node_ids) | |||||
| kw_index.delete_by_ids(index_node_ids) | |||||
| end_at = time.perf_counter() | end_at = time.perf_counter() | ||||
| logging.info( | logging.info( |
| from celery import shared_task | from celery import shared_task | ||||
| from werkzeug.exceptions import NotFound | from werkzeug.exceptions import NotFound | ||||
| from core.index.keyword_table_index import KeywordTableIndex | |||||
| from core.index.vector_index import VectorIndex | |||||
| from core.index.index import IndexBuilder | |||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from extensions.ext_redis import redis_client | from extensions.ext_redis import redis_client | ||||
| from models.dataset import DocumentSegment | from models.dataset import DocumentSegment | ||||
| dataset = segment.dataset | dataset = segment.dataset | ||||
| if not dataset: | if not dataset: | ||||
| raise Exception('Segment has no dataset') | |||||
| logging.info(click.style('Segment {} has no dataset, pass.'.format(segment.id), fg='cyan')) | |||||
| return | |||||
| vector_index = VectorIndex(dataset=dataset) | |||||
| keyword_table_index = KeywordTableIndex(dataset=dataset) | |||||
| dataset_document = segment.document | |||||
| if not dataset_document: | |||||
| logging.info(click.style('Segment {} has no document, pass.'.format(segment.id), fg='cyan')) | |||||
| return | |||||
| if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed': | |||||
| logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment.id), fg='cyan')) | |||||
| return | |||||
| vector_index = IndexBuilder.get_index(dataset, 'high_quality') | |||||
| kw_index = IndexBuilder.get_index(dataset, 'economy') | |||||
| # delete from vector index | # delete from vector index | ||||
| if dataset.indexing_technique == "high_quality": | |||||
| vector_index.del_nodes([segment.index_node_id]) | |||||
| if vector_index: | |||||
| vector_index.delete_by_ids([segment.index_node_id]) | |||||
| # delete from keyword index | # delete from keyword index | ||||
| keyword_table_index.del_nodes([segment.index_node_id]) | |||||
| kw_index.delete_by_ids([segment.index_node_id]) | |||||
| end_at = time.perf_counter() | end_at = time.perf_counter() | ||||
| logging.info(click.style('Segment removed from index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green')) | logging.info(click.style('Segment removed from index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green')) |
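The segment-removal task above also trades hard failures for early returns: a missing dataset, a missing document, or a document that is disabled, archived, or not fully indexed now logs and exits instead of raising. A sketch of the guard chain plus the two deletions, under the same IndexBuilder assumption (remove_segment is an illustrative name, and the two early-return branches for missing and invalid documents are merged here for brevity):

import logging

import click

from core.index.index import IndexBuilder


def remove_segment(segment):
    dataset = segment.dataset
    if not dataset:
        logging.info(click.style('Segment {} has no dataset, pass.'.format(segment.id), fg='cyan'))
        return

    dataset_document = segment.document
    if not dataset_document or not dataset_document.enabled \
            or dataset_document.archived or dataset_document.indexing_status != 'completed':
        # Nothing to remove for documents that are missing or not (or no longer) indexed.
        logging.info(click.style('Segment {} document is missing or invalid, pass.'.format(segment.id), fg='cyan'))
        return

    vector_index = IndexBuilder.get_index(dataset, 'high_quality')
    kw_index = IndexBuilder.get_index(dataset, 'economy')

    # The vector index only exists for high-quality datasets, hence the guard.
    if vector_index:
        vector_index.delete_by_ids([segment.index_node_id])

    # Keyword ("economy") index removal is unconditional.
    kw_index.delete_by_ids([segment.index_node_id])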