Co-authored-by: jyong <jyong@dify.ai>tags/0.3.31
| @@ -8,6 +8,8 @@ import time | |||
| import uuid | |||
| import click | |||
| import qdrant_client | |||
| from qdrant_client.http.models import TextIndexParams, TextIndexType, TokenizerType | |||
| from tqdm import tqdm | |||
| from flask import current_app, Flask | |||
| from langchain.embeddings import OpenAIEmbeddings | |||
| @@ -484,6 +486,38 @@ def normalization_collections(): | |||
| click.echo(click.style('Congratulations! restore {} dataset indexes.'.format(len(normalization_count)), fg='green')) | |||
@click.command('add-qdrant-full-text-index', help='add qdrant full text index')
def add_qdrant_full_text_index():
    """Create a full-text payload index on ``page_content`` for each bound Qdrant collection.

    Only the start message is printed when there are no collection bindings or
    when the configured ``VECTOR_STORE`` is not ``qdrant``. A failure on one
    collection is reported in red and the remaining collections are still
    processed; the success message is printed only for collections whose index
    was actually created.
    """
    click.echo(click.style('Start add full text index.', fg='green'))
    binds = db.session.query(DatasetCollectionBinding).all()
    if not binds or current_app.config['VECTOR_STORE'] != 'qdrant':
        return

    qdrant_url = current_app.config['QDRANT_URL']
    qdrant_api_key = current_app.config['QDRANT_API_KEY']
    client = qdrant_client.QdrantClient(
        qdrant_url,
        api_key=qdrant_api_key,  # For Qdrant Cloud, None for local instance
    )
    # The index definition is the same for every collection, so build it once.
    text_index_params = TextIndexParams(
        type=TextIndexType.TEXT,
        tokenizer=TokenizerType.MULTILINGUAL,
        min_token_len=2,
        max_token_len=20,
        lowercase=True
    )
    for bind in binds:
        try:
            client.create_payload_index(bind.collection_name, 'page_content',
                                        field_schema=text_index_params)
        except Exception as e:
            # Deliberately best-effort: report the failure and move on to the next collection.
            click.echo(
                click.style('Create full text index error: {} {}'.format(e.__class__.__name__, str(e)),
                            fg='red'))
        else:
            # Reached only when no exception was raised, so a failed collection
            # is no longer reported as successful.
            click.echo(
                click.style(
                    'Congratulations! add collection {} full text index successful.'.format(bind.collection_name),
                    fg='green'))
| def deal_dataset_vector(flask_app: Flask, dataset: Dataset, normalization_count: list): | |||
| with flask_app.app_context(): | |||
| try: | |||
| @@ -647,10 +681,10 @@ def update_app_model_configs(batch_size): | |||
| pbar.update(len(data_batch)) | |||
| @click.command('migrate_default_input_to_dataset_query_variable') | |||
| @click.option("--batch-size", default=500, help="Number of records to migrate in each batch.") | |||
| def migrate_default_input_to_dataset_query_variable(batch_size): | |||
| click.secho("Starting...", fg='green') | |||
| total_records = db.session.query(AppModelConfig) \ | |||
| @@ -658,13 +692,13 @@ def migrate_default_input_to_dataset_query_variable(batch_size): | |||
| .filter(App.mode == 'completion') \ | |||
| .filter(AppModelConfig.dataset_query_variable == None) \ | |||
| .count() | |||
| if total_records == 0: | |||
| click.secho("No data to migrate.", fg='green') | |||
| return | |||
| num_batches = (total_records + batch_size - 1) // batch_size | |||
| with tqdm(total=total_records, desc="Migrating Data") as pbar: | |||
| for i in range(num_batches): | |||
| offset = i * batch_size | |||
| @@ -697,14 +731,14 @@ def migrate_default_input_to_dataset_query_variable(batch_size): | |||
| for form in user_input_form: | |||
| paragraph = form.get('paragraph') | |||
| if paragraph \ | |||
| and paragraph.get('variable') == 'query': | |||
| data.dataset_query_variable = 'query' | |||
| break | |||
| and paragraph.get('variable') == 'query': | |||
| data.dataset_query_variable = 'query' | |||
| break | |||
| if paragraph \ | |||
| and paragraph.get('variable') == 'default_input': | |||
| data.dataset_query_variable = 'default_input' | |||
| break | |||
| and paragraph.get('variable') == 'default_input': | |||
| data.dataset_query_variable = 'default_input' | |||
| break | |||
| db.session.commit() | |||
| @@ -712,7 +746,7 @@ def migrate_default_input_to_dataset_query_variable(batch_size): | |||
| click.secho(f"Error while migrating data: {e}, app_id: {data.app_id}, app_model_config_id: {data.id}", | |||
| fg='red') | |||
| continue | |||
| click.secho(f"Successfully migrated batch {i + 1}/{num_batches}.", fg='green') | |||
| pbar.update(len(data_batch)) | |||
| @@ -731,3 +765,4 @@ def register_commands(app): | |||
| app.cli.add_command(update_app_model_configs) | |||
| app.cli.add_command(normalization_collections) | |||
| app.cli.add_command(migrate_default_input_to_dataset_query_variable) | |||
| app.cli.add_command(add_qdrant_full_text_index) | |||
| @@ -170,6 +170,7 @@ class DatasetApi(Resource): | |||
| help='Invalid indexing technique.') | |||
| parser.add_argument('permission', type=str, location='json', choices=( | |||
| 'only_me', 'all_team_members'), help='Invalid permission.') | |||
| parser.add_argument('retrieval_model', type=dict, location='json', help='Invalid retrieval model.') | |||
| args = parser.parse_args() | |||
| # The role of the current user in the ta table must be admin or owner | |||
| @@ -401,6 +402,7 @@ class DatasetApiKeyApi(Resource): | |||
| class DatasetApiDeleteApi(Resource): | |||
| resource_type = 'dataset' | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @@ -436,6 +438,50 @@ class DatasetApiBaseUrlApi(Resource): | |||
| } | |||
class DatasetRetrievalSettingApi(Resource):
    """Expose the retrieval methods available for the configured vector store."""

    @setup_required
    @login_required
    @account_initialization_required
    def get(self):
        """Return the retrieval methods for ``current_app.config['VECTOR_STORE']``.

        :return: a dict with a single ``retrieval_method`` list
        :raises ValueError: when the store is not milvus, qdrant or weaviate
        """
        vector_type = current_app.config['VECTOR_STORE']
        if vector_type == 'milvus':
            methods = ['semantic_search']
        elif vector_type in ('qdrant', 'weaviate'):
            methods = ['semantic_search', 'full_text_search', 'hybrid_search']
        else:
            raise ValueError("Unsupported vector db type.")
        return {'retrieval_method': methods}
class DatasetRetrievalSettingMockApi(Resource):
    """Expose the retrieval methods for a vector store type named in the URL."""

    @setup_required
    @login_required
    @account_initialization_required
    def get(self, vector_type):
        """Return the retrieval methods for the given ``vector_type``.

        :param vector_type: vector store name taken from the route
        :return: a dict with a single ``retrieval_method`` list
        :raises ValueError: when the type is not milvus, qdrant or weaviate
        """
        if vector_type == 'milvus':
            methods = ['semantic_search']
        elif vector_type in ('qdrant', 'weaviate'):
            methods = ['semantic_search', 'full_text_search', 'hybrid_search']
        else:
            raise ValueError("Unsupported vector db type.")
        return {'retrieval_method': methods}
| api.add_resource(DatasetListApi, '/datasets') | |||
| api.add_resource(DatasetApi, '/datasets/<uuid:dataset_id>') | |||
| api.add_resource(DatasetQueryApi, '/datasets/<uuid:dataset_id>/queries') | |||
| @@ -445,3 +491,5 @@ api.add_resource(DatasetIndexingStatusApi, '/datasets/<uuid:dataset_id>/indexing | |||
| api.add_resource(DatasetApiKeyApi, '/datasets/api-keys') | |||
| api.add_resource(DatasetApiDeleteApi, '/datasets/api-keys/<uuid:api_key_id>') | |||
| api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info') | |||
| api.add_resource(DatasetRetrievalSettingApi, '/datasets/retrieval-setting') | |||
| api.add_resource(DatasetRetrievalSettingMockApi, '/datasets/retrieval-setting/<string:vector_type>') | |||
| @@ -221,6 +221,8 @@ class DatasetDocumentListApi(Resource): | |||
| parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json') | |||
| parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, | |||
| location='json') | |||
| parser.add_argument('retrieval_model', type=dict, required=False, nullable=False, | |||
| location='json') | |||
| args = parser.parse_args() | |||
| if not dataset.indexing_technique and not args['indexing_technique']: | |||
| @@ -263,6 +265,8 @@ class DatasetInitApi(Resource): | |||
| parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json') | |||
| parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, | |||
| location='json') | |||
| parser.add_argument('retrieval_model', type=dict, required=False, nullable=False, | |||
| location='json') | |||
| args = parser.parse_args() | |||
| if args['indexing_technique'] == 'high_quality': | |||
| try: | |||
| @@ -42,19 +42,18 @@ class HitTestingApi(Resource): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('query', type=str, location='json') | |||
| parser.add_argument('retrieval_model', type=dict, required=False, location='json') | |||
| args = parser.parse_args() | |||
| query = args['query'] | |||
| if not query or len(query) > 250: | |||
| raise ValueError('Query is required and cannot exceed 250 characters') | |||
| HitTestingService.hit_testing_args_check(args) | |||
| try: | |||
| response = HitTestingService.retrieve( | |||
| dataset=dataset, | |||
| query=query, | |||
| query=args['query'], | |||
| account=current_user, | |||
| limit=10, | |||
| retrieval_model=args['retrieval_model'], | |||
| limit=10 | |||
| ) | |||
| return {"query": response['query'], 'records': marshal(response['records'], hit_testing_record_fields)} | |||
| @@ -19,7 +19,7 @@ class DefaultModelApi(Resource): | |||
| def get(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('model_type', type=str, required=True, nullable=False, | |||
| choices=['text-generation', 'embeddings', 'speech2text'], location='args') | |||
| choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='args') | |||
| args = parser.parse_args() | |||
| tenant_id = current_user.current_tenant_id | |||
| @@ -71,19 +71,18 @@ class DefaultModelApi(Resource): | |||
| @account_initialization_required | |||
| def post(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('model_name', type=str, required=True, nullable=False, location='json') | |||
| parser.add_argument('model_type', type=str, required=True, nullable=False, | |||
| choices=['text-generation', 'embeddings', 'speech2text'], location='json') | |||
| parser.add_argument('provider_name', type=str, required=True, nullable=False, location='json') | |||
| parser.add_argument('model_settings', type=list, required=True, nullable=False, location='json') | |||
| args = parser.parse_args() | |||
| provider_service = ProviderService() | |||
| provider_service.update_default_model_of_model_type( | |||
| tenant_id=current_user.current_tenant_id, | |||
| model_type=args['model_type'], | |||
| provider_name=args['provider_name'], | |||
| model_name=args['model_name'] | |||
| ) | |||
| model_settings = args['model_settings'] | |||
| for model_setting in model_settings: | |||
| provider_service.update_default_model_of_model_type( | |||
| tenant_id=current_user.current_tenant_id, | |||
| model_type=model_setting['model_type'], | |||
| provider_name=model_setting['provider_name'], | |||
| model_name=model_setting['model_name'] | |||
| ) | |||
| return {'result': 'success'} | |||
| @@ -36,6 +36,8 @@ class DocumentAddByTextApi(DatasetApiResource): | |||
| location='json') | |||
| parser.add_argument('indexing_technique', type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, | |||
| location='json') | |||
| parser.add_argument('retrieval_model', type=dict, required=False, nullable=False, | |||
| location='json') | |||
| args = parser.parse_args() | |||
| dataset_id = str(dataset_id) | |||
| tenant_id = str(tenant_id) | |||
| @@ -95,6 +97,8 @@ class DocumentUpdateByTextApi(DatasetApiResource): | |||
| parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json') | |||
| parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, | |||
| location='json') | |||
| parser.add_argument('retrieval_model', type=dict, required=False, nullable=False, | |||
| location='json') | |||
| args = parser.parse_args() | |||
| dataset_id = str(dataset_id) | |||
| tenant_id = str(tenant_id) | |||
| @@ -14,7 +14,6 @@ from pydantic import root_validator | |||
| from core.model_providers.models.entity.message import to_prompt_messages | |||
| from core.model_providers.models.llm.base import BaseLLM | |||
| from core.third_party.langchain.llms.fake import FakeLLM | |||
| from core.tool.dataset_retriever_tool import DatasetRetrieverTool | |||
| class MultiDatasetRouterAgent(OpenAIFunctionsAgent): | |||
| @@ -60,7 +59,6 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent): | |||
| return AgentFinish(return_values={"output": ''}, log='') | |||
| elif len(self.tools) == 1: | |||
| tool = next(iter(self.tools)) | |||
| tool = cast(DatasetRetrieverTool, tool) | |||
| rst = tool.run(tool_input={'query': kwargs['input']}) | |||
| # output = '' | |||
| # rst_json = json.loads(rst) | |||
| @@ -0,0 +1,158 @@ | |||
| import json | |||
| from typing import Tuple, List, Any, Union, Sequence, Optional, cast | |||
| from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent | |||
| from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message | |||
| from langchain.callbacks.base import BaseCallbackManager | |||
| from langchain.callbacks.manager import Callbacks | |||
| from langchain.prompts.chat import BaseMessagePromptTemplate | |||
| from langchain.schema import AgentAction, AgentFinish, SystemMessage, Generation, LLMResult, AIMessage | |||
| from langchain.schema.language_model import BaseLanguageModel | |||
| from langchain.tools import BaseTool | |||
| from pydantic import root_validator | |||
| from core.model_providers.models.entity.message import to_prompt_messages | |||
| from core.model_providers.models.llm.base import BaseLLM | |||
| from core.third_party.langchain.llms.fake import FakeLLM | |||
| from core.tool.dataset_retriever_tool import DatasetRetrieverTool | |||
class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
    """
    A multi-dataset retrieval agent driven by a router.

    Tool selection is performed through ``model_instance`` (see ``real_plan``);
    ``from_llm_and_tools`` fills the ``llm`` field with a ``FakeLLM`` placeholder.
    """
    # Model used by real_plan() to pick a tool and by plan() to translate exceptions.
    model_instance: BaseLLM

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    @root_validator
    def validate_llm(cls, values: dict) -> dict:
        """Return ``values`` unchanged; no validation is performed here."""
        # NOTE(review): presumably overrides a stricter validator on the parent class — confirm.
        return values

    def should_use_agent(self, query: str) -> bool:
        """
        Report whether the agent should be used for ``query``.

        :param query: the user query (not inspected)
        :return: always ``True``
        """
        return True

    def plan(
            self,
            intermediate_steps: List[Tuple[AgentAction, str]],
            callbacks: Callbacks = None,
            **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        """Given the input, decide what to do.

        Args:
            intermediate_steps: Steps the LLM has taken to date, along with observations
            **kwargs: User inputs; ``kwargs['input']`` is used as the retrieval query.

        Returns:
            Action specifying what tool to use, or an ``AgentFinish`` carrying the output.

        Raises:
            Whatever ``model_instance.handle_exceptions`` returns for an error
            raised while planning.
        """
        if len(self.tools) == 0:
            # Nothing to route to: finish immediately with an empty output.
            return AgentFinish(return_values={"output": ''}, log='')
        elif len(self.tools) == 1:
            # A single tool needs no routing decision: run it directly with the user input.
            tool = next(iter(self.tools))
            tool = cast(DatasetRetrieverTool, tool)
            rst = tool.run(tool_input={'query': kwargs['input']})
            # output = ''
            # rst_json = json.loads(rst)
            # for item in rst_json:
            #     output += f'{item["content"]}\n'
            return AgentFinish(return_values={"output": rst}, log=rst)

        if intermediate_steps:
            # A tool has already run: finish with the most recent observation.
            _, observation = intermediate_steps[-1]
            return AgentFinish(return_values={"output": observation}, log=observation)

        try:
            agent_decision = self.real_plan(intermediate_steps, callbacks, **kwargs)
            if isinstance(agent_decision, AgentAction):
                tool_inputs = agent_decision.tool_input
                # Without chat history, replace the model-written query with the raw user input.
                if isinstance(tool_inputs, dict) and 'query' in tool_inputs and 'chat_history' not in kwargs:
                    tool_inputs['query'] = kwargs['input']
                    agent_decision.tool_input = tool_inputs
            else:
                # No tool was selected: blank the output of the finish result.
                agent_decision.return_values['output'] = ''
            return agent_decision
        except Exception as e:
            # Let the model wrapper map the error to the exception that is raised to the caller.
            new_exception = self.model_instance.handle_exceptions(e)
            raise new_exception

    def real_plan(
            self,
            intermediate_steps: List[Tuple[AgentAction, str]],
            callbacks: Callbacks = None,
            **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        """Ask ``model_instance`` which tool to call.

        Args:
            intermediate_steps: Steps the LLM has taken to date, along with observations
            **kwargs: User inputs; must contain every prompt input variable
                except ``agent_scratchpad``.

        Returns:
            Action specifying what tool to use, as parsed from the model response.
        """
        agent_scratchpad = _format_intermediate_steps(intermediate_steps)
        selected_inputs = {
            k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
        }
        full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
        prompt = self.prompt.format_prompt(**full_inputs)
        messages = prompt.to_messages()
        prompt_messages = to_prompt_messages(messages)
        result = self.model_instance.run(
            messages=prompt_messages,
            functions=self.functions,
        )

        # Wrap the result in an AIMessage so the langchain parser can interpret the function call.
        ai_message = AIMessage(
            content=result.content,
            additional_kwargs={
                'function_call': result.function_call
            }
        )

        agent_decision = _parse_ai_message(ai_message)
        return agent_decision

    async def aplan(
            self,
            intermediate_steps: List[Tuple[AgentAction, str]],
            callbacks: Callbacks = None,
            **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        """Asynchronous planning is not implemented; always raises ``NotImplementedError``."""
        raise NotImplementedError()

    @classmethod
    def from_llm_and_tools(
            cls,
            model_instance: BaseLLM,
            tools: Sequence[BaseTool],
            callback_manager: Optional[BaseCallbackManager] = None,
            extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
            system_message: Optional[SystemMessage] = SystemMessage(
                content="You are a helpful AI assistant."
            ),
            **kwargs: Any,
    ) -> BaseSingleActionAgent:
        """Build an agent from a model instance and a set of tools.

        Args:
            model_instance: model stored on the agent and used by ``real_plan``
            tools: tools the agent may route to
            callback_manager: optional callback manager passed to the agent
            extra_prompt_messages: optional extra messages passed to ``create_prompt``
            system_message: system message passed to ``create_prompt``
            **kwargs: forwarded to the agent constructor

        Returns:
            The constructed agent, with ``llm`` set to a ``FakeLLM`` placeholder.
        """
        prompt = cls.create_prompt(
            extra_prompt_messages=extra_prompt_messages,
            system_message=system_message,
        )
        return cls(
            model_instance=model_instance,
            llm=FakeLLM(response=''),
            prompt=prompt,
            tools=tools,
            callback_manager=callback_manager,
            **kwargs,
        )
| @@ -89,7 +89,6 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent): | |||
| return AgentFinish(return_values={"output": ''}, log='') | |||
| elif len(self.dataset_tools) == 1: | |||
| tool = next(iter(self.dataset_tools)) | |||
| tool = cast(DatasetRetrieverTool, tool) | |||
| rst = tool.run(tool_input={'query': kwargs['input']}) | |||
| return AgentFinish(return_values={"output": rst}, log=rst) | |||
| @@ -18,6 +18,7 @@ from langchain.agents import AgentExecutor as LCAgentExecutor | |||
| from core.helper import moderation | |||
| from core.model_providers.error import LLMError | |||
| from core.model_providers.models.llm.base import BaseLLM | |||
| from core.tool.dataset_multi_retriever_tool import DatasetMultiRetrieverTool | |||
| from core.tool.dataset_retriever_tool import DatasetRetrieverTool | |||
| @@ -78,7 +79,7 @@ class AgentExecutor: | |||
| verbose=True | |||
| ) | |||
| elif self.configuration.strategy == PlanningStrategy.ROUTER: | |||
| self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)] | |||
| self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool) or isinstance(t, DatasetMultiRetrieverTool)] | |||
| agent = MultiDatasetRouterAgent.from_llm_and_tools( | |||
| model_instance=self.configuration.model_instance, | |||
| tools=self.configuration.tools, | |||
| @@ -86,7 +87,7 @@ class AgentExecutor: | |||
| verbose=True | |||
| ) | |||
| elif self.configuration.strategy == PlanningStrategy.REACT_ROUTER: | |||
| self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)] | |||
| self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool) or isinstance(t, DatasetMultiRetrieverTool)] | |||
| agent = StructuredMultiDatasetRouterAgent.from_llm_and_tools( | |||
| model_instance=self.configuration.model_instance, | |||
| tools=self.configuration.tools, | |||
| @@ -10,8 +10,7 @@ from models.dataset import DocumentSegment | |||
| class DatasetIndexToolCallbackHandler: | |||
| """Callback handler for dataset tool.""" | |||
| def __init__(self, dataset_id: str, conversation_message_task: ConversationMessageTask) -> None: | |||
| self.dataset_id = dataset_id | |||
| def __init__(self, conversation_message_task: ConversationMessageTask) -> None: | |||
| self.conversation_message_task = conversation_message_task | |||
| def on_tool_end(self, documents: List[Document]) -> None: | |||
| @@ -21,7 +20,6 @@ class DatasetIndexToolCallbackHandler: | |||
| # add hit count to document segment | |||
| db.session.query(DocumentSegment).filter( | |||
| DocumentSegment.dataset_id == self.dataset_id, | |||
| DocumentSegment.index_node_id == doc_id | |||
| ).update( | |||
| {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, | |||
| @@ -127,6 +127,7 @@ class Completion: | |||
| memory=memory, | |||
| rest_tokens=rest_tokens_for_context_and_memory, | |||
| chain_callback=chain_callback, | |||
| tenant_id=app.tenant_id, | |||
| retriever_from=retriever_from | |||
| ) | |||
| @@ -3,7 +3,7 @@ from pathlib import Path | |||
| from typing import List, Union, Optional | |||
| import requests | |||
| from langchain.document_loaders import TextLoader, Docx2txtLoader | |||
| from langchain.document_loaders import TextLoader, Docx2txtLoader, UnstructuredFileLoader, UnstructuredAPIFileLoader | |||
| from langchain.schema import Document | |||
| from core.data_loader.loader.csv_loader import CSVLoader | |||
| @@ -20,13 +20,13 @@ USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTM | |||
| class FileExtractor: | |||
| @classmethod | |||
| def load(cls, upload_file: UploadFile, return_text: bool = False) -> Union[List[Document] | str]: | |||
| def load(cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False) -> Union[List[Document] | str]: | |||
| with tempfile.TemporaryDirectory() as temp_dir: | |||
| suffix = Path(upload_file.key).suffix | |||
| file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" | |||
| storage.download(upload_file.key, file_path) | |||
| return cls.load_from_file(file_path, return_text, upload_file) | |||
| return cls.load_from_file(file_path, return_text, upload_file, is_automatic) | |||
| @classmethod | |||
| def load_from_url(cls, url: str, return_text: bool = False) -> Union[List[Document] | str]: | |||
| @@ -44,24 +44,34 @@ class FileExtractor: | |||
| @classmethod | |||
| def load_from_file(cls, file_path: str, return_text: bool = False, | |||
| upload_file: Optional[UploadFile] = None) -> Union[List[Document] | str]: | |||
| upload_file: Optional[UploadFile] = None, | |||
| is_automatic: bool = False) -> Union[List[Document] | str]: | |||
| input_file = Path(file_path) | |||
| delimiter = '\n' | |||
| file_extension = input_file.suffix.lower() | |||
| if file_extension == '.xlsx': | |||
| loader = ExcelLoader(file_path) | |||
| elif file_extension == '.pdf': | |||
| loader = PdfLoader(file_path, upload_file=upload_file) | |||
| elif file_extension in ['.md', '.markdown']: | |||
| loader = MarkdownLoader(file_path, autodetect_encoding=True) | |||
| elif file_extension in ['.htm', '.html']: | |||
| loader = HTMLLoader(file_path) | |||
| elif file_extension == '.docx': | |||
| loader = Docx2txtLoader(file_path) | |||
| elif file_extension == '.csv': | |||
| loader = CSVLoader(file_path, autodetect_encoding=True) | |||
| if is_automatic: | |||
| loader = UnstructuredFileLoader( | |||
| file_path, strategy="hi_res", mode="elements" | |||
| ) | |||
| # loader = UnstructuredAPIFileLoader( | |||
| # file_path=filenames[0], | |||
| # api_key="FAKE_API_KEY", | |||
| # ) | |||
| else: | |||
| # txt | |||
| loader = TextLoader(file_path, autodetect_encoding=True) | |||
| if file_extension == '.xlsx': | |||
| loader = ExcelLoader(file_path) | |||
| elif file_extension == '.pdf': | |||
| loader = PdfLoader(file_path, upload_file=upload_file) | |||
| elif file_extension in ['.md', '.markdown']: | |||
| loader = MarkdownLoader(file_path, autodetect_encoding=True) | |||
| elif file_extension in ['.htm', '.html']: | |||
| loader = HTMLLoader(file_path) | |||
| elif file_extension == '.docx': | |||
| loader = Docx2txtLoader(file_path) | |||
| elif file_extension == '.csv': | |||
| loader = CSVLoader(file_path, autodetect_encoding=True) | |||
| else: | |||
| # txt | |||
| loader = TextLoader(file_path, autodetect_encoding=True) | |||
| return delimiter.join([document.page_content for document in loader.load()]) if return_text else loader.load() | |||
| @@ -40,6 +40,13 @@ class BaseVectorIndex(BaseIndex): | |||
| def _get_vector_store_class(self) -> type: | |||
| raise NotImplementedError | |||
@abstractmethod
def search_by_full_text_index(self, query: str, **kwargs: Any) -> List[Document]:
    """Search the index by full text; concrete vector indexes must override this.

    :param query: text to search for
    :param kwargs: backend-specific search options
    :return: the matching documents
    :raises NotImplementedError: always, in this base declaration
    """
    raise NotImplementedError
| def search( | |||
| self, query: str, | |||
| **kwargs: Any | |||
| @@ -1,16 +1,14 @@ | |||
| from typing import Optional, cast | |||
| from typing import cast, Any, List | |||
| from langchain.embeddings.base import Embeddings | |||
| from langchain.schema import Document, BaseRetriever | |||
| from langchain.vectorstores import VectorStore, milvus | |||
| from langchain.schema import Document | |||
| from langchain.vectorstores import VectorStore | |||
| from pydantic import BaseModel, root_validator | |||
| from core.index.base import BaseIndex | |||
| from core.index.vector_index.base import BaseVectorIndex | |||
| from core.vector_store.milvus_vector_store import MilvusVectorStore | |||
| from core.vector_store.weaviate_vector_store import WeaviateVectorStore | |||
| from extensions.ext_database import db | |||
| from models.dataset import Dataset, DatasetCollectionBinding | |||
| from models.dataset import Dataset | |||
| class MilvusConfig(BaseModel): | |||
| @@ -74,7 +72,7 @@ class MilvusVectorIndex(BaseVectorIndex): | |||
| index_params = { | |||
| 'metric_type': 'IP', | |||
| 'index_type': "HNSW", | |||
| 'params': {"M": 8, "efConstruction": 64} | |||
| 'params': {"M": 8, "efConstruction": 64} | |||
| } | |||
| self._vector_store = MilvusVectorStore.from_documents( | |||
| texts, | |||
| @@ -152,3 +150,7 @@ class MilvusVectorIndex(BaseVectorIndex): | |||
| ), | |||
| ], | |||
| )) | |||
def search_by_full_text_index(self, query: str, **kwargs: Any) -> List[Document]:
    """Return no results for a full-text query.

    :param query: text to search for (ignored)
    :param kwargs: search options (ignored)
    :return: always an empty list
    """
    # milvus/zilliz doesn't support bm25 search
    no_results = []
    return no_results
| @@ -191,3 +191,21 @@ class QdrantVectorIndex(BaseVectorIndex): | |||
| return True | |||
| return False | |||
def search_by_full_text_index(self, query: str, **kwargs: Any) -> List[Document]:
    """Run a full-text search on ``page_content`` restricted to this dataset.

    :param query: text matched against the ``page_content`` payload field
    :param kwargs: ``top_k`` sets the number of results requested (default 2)
    :return: the documents returned by the vector store
    """
    store = self._get_vector_store()
    store = cast(self._get_vector_store_class(), store)

    from qdrant_client.http import models

    # Both conditions must hold: the point belongs to this dataset and its text matches.
    dataset_condition = models.FieldCondition(
        key="group_id",
        match=models.MatchValue(value=self.dataset.id),
    )
    text_condition = models.FieldCondition(
        key="page_content",
        match=models.MatchText(text=query),
    )
    search_filter = models.Filter(must=[dataset_condition, text_condition])
    limit = kwargs.get('top_k', 2)
    return store.similarity_search_by_bm25(search_filter, limit)
| @@ -1,4 +1,4 @@ | |||
| from typing import Optional, cast | |||
| from typing import Optional, cast, Any, List | |||
| import requests | |||
| import weaviate | |||
| @@ -26,6 +26,7 @@ class WeaviateConfig(BaseModel): | |||
| class WeaviateVectorIndex(BaseVectorIndex): | |||
| def __init__(self, dataset: Dataset, config: WeaviateConfig, embeddings: Embeddings): | |||
| super().__init__(dataset, embeddings) | |||
| self._client = self._init_client(config) | |||
| @@ -148,3 +149,9 @@ class WeaviateVectorIndex(BaseVectorIndex): | |||
| return True | |||
| return False | |||
def search_by_full_text_index(self, query: str, **kwargs: Any) -> List[Document]:
    """Run a BM25 search against the Weaviate vector store.

    :param query: text to search for
    :param kwargs: forwarded to the store; ``top_k`` is also passed as the
        second positional argument (default 2)
    :return: the documents returned by the vector store
    """
    store = self._get_vector_store()
    store = cast(self._get_vector_store_class(), store)
    limit = kwargs.get('top_k', 2)
    return store.similarity_search_by_bm25(query, limit, **kwargs)
| @@ -49,14 +49,14 @@ class IndexingRunner: | |||
| if not dataset: | |||
| raise ValueError("no dataset found") | |||
| # load file | |||
| text_docs = self._load_data(dataset_document) | |||
| # get the process rule | |||
| processing_rule = db.session.query(DatasetProcessRule). \ | |||
| filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \ | |||
| first() | |||
| # load file | |||
| text_docs = self._load_data(dataset_document) | |||
| # get splitter | |||
| splitter = self._get_splitter(processing_rule) | |||
| @@ -380,7 +380,7 @@ class IndexingRunner: | |||
| "preview": preview_texts | |||
| } | |||
| def _load_data(self, dataset_document: DatasetDocument) -> List[Document]: | |||
| def _load_data(self, dataset_document: DatasetDocument, automatic: bool = False) -> List[Document]: | |||
| # load file | |||
| if dataset_document.data_source_type not in ["upload_file", "notion_import"]: | |||
| return [] | |||
| @@ -396,7 +396,7 @@ class IndexingRunner: | |||
| one_or_none() | |||
| if file_detail: | |||
| text_docs = FileExtractor.load(file_detail) | |||
| text_docs = FileExtractor.load(file_detail, is_automatic=False) | |||
| elif dataset_document.data_source_type == 'notion_import': | |||
| loader = NotionLoader.from_document(dataset_document) | |||
| text_docs = loader.load() | |||
| @@ -9,6 +9,7 @@ from core.model_providers.models.embedding.base import BaseEmbedding | |||
| from core.model_providers.models.entity.model_params import ModelKwargs, ModelType | |||
| from core.model_providers.models.llm.base import BaseLLM | |||
| from core.model_providers.models.moderation.base import BaseModeration | |||
| from core.model_providers.models.reranking.base import BaseReranking | |||
| from core.model_providers.models.speech2text.base import BaseSpeech2Text | |||
| from extensions.ext_database import db | |||
| from models.provider import TenantDefaultModel | |||
| @@ -140,6 +141,44 @@ class ModelFactory: | |||
| name=model_name | |||
| ) | |||
@classmethod
def get_reranking_model(cls,
                        tenant_id: str,
                        model_provider_name: Optional[str] = None,
                        model_name: Optional[str] = None) -> Optional[BaseReranking]:
    """
    Build the reranking model for a tenant.

    When neither a provider name nor a model name is given, the tenant's
    default reranking model is used.

    :param tenant_id: a string representing the ID of the tenant.
    :param model_provider_name: provider to use, or None
    :param model_name: model to use, or None
    :return: an instance of the provider's reranking model class
    :raises LLMBadRequestError: if the default is needed but none is configured
    :raises ProviderTokenNotInitError: if no preferred provider is found
    """
    use_default = model_provider_name is None and model_name is None
    if use_default:
        default_model = cls.get_default_model(tenant_id, ModelType.RERANKING)
        if not default_model:
            raise LLMBadRequestError("Default model is not available. "
                                     "Please configure a Default Reranking Model "
                                     "in the Settings -> Model Provider.")
        model_provider_name, model_name = default_model.provider_name, default_model.model_name

    # get model provider
    provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
    if not provider:
        raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")

    # init reranking model
    reranking_cls = provider.get_model_class(model_type=ModelType.RERANKING)
    return reranking_cls(model_provider=provider, name=model_name)
| @classmethod | |||
| def get_speech2text_model(cls, | |||
| tenant_id: str, | |||
| @@ -72,6 +72,9 @@ class ModelProviderFactory: | |||
| elif provider_name == 'localai': | |||
| from core.model_providers.providers.localai_provider import LocalAIProvider | |||
| return LocalAIProvider | |||
| elif provider_name == 'cohere': | |||
| from core.model_providers.providers.cohere_provider import CohereProvider | |||
| return CohereProvider | |||
| else: | |||
| raise NotImplementedError | |||
| @@ -17,7 +17,7 @@ class ModelType(enum.Enum): | |||
| IMAGE = 'image' | |||
| VIDEO = 'video' | |||
| MODERATION = 'moderation' | |||
| RERANKING = 'reranking' | |||
| @staticmethod | |||
| def value_of(value): | |||
| for member in ModelType: | |||
| @@ -0,0 +1,36 @@ | |||
| from abc import abstractmethod | |||
| from typing import Any, Optional, List | |||
| from langchain.schema import Document | |||
| from core.model_providers.models.base import BaseProviderModel | |||
| from core.model_providers.models.entity.model_params import ModelType | |||
| from core.model_providers.providers.base import BaseModelProvider | |||
| import logging | |||
| logger = logging.getLogger(__name__) | |||
class BaseReranking(BaseProviderModel):
    """
    Common base for reranking models.

    Concrete subclasses supply the provider client and implement
    ``rerank`` and ``handle_exceptions``.
    """
    # model identifier; also returned by ``base_model_name``
    name: str
    # every subclass is a reranking-type model
    type: ModelType = ModelType.RERANKING

    def __init__(self, model_provider: BaseModelProvider, client: Any, name: str):
        """
        :param model_provider: provider that owns this model.
        :param client: provider client object, forwarded to the parent initializer.
        :param name: model name.
        """
        super().__init__(model_provider, client)
        self.name = name

    @property
    def base_model_name(self) -> str:
        """
        get base model name

        :return: str
        """
        return self.name

    @abstractmethod
    def rerank(self,
               query: str,
               documents: List[Document],
               score_threshold: Optional[float],
               top_k: Optional[int]) -> Optional[List[Document]]:
        """
        Reorder ``documents`` by relevance to ``query``.

        :param query: the search query.
        :param documents: candidate documents.
        :param score_threshold: minimum score to keep a document, or None.
        :param top_k: maximum number of documents to return, or None.
        :return: the reranked documents.
        """
        raise NotImplementedError

    @abstractmethod
    def handle_exceptions(self, ex: Exception) -> Exception:
        """
        Translate a provider exception into the exception callers should see.

        :param ex: the exception that was raised.
        :return: the exception to surface.
        """
        raise NotImplementedError
| @@ -0,0 +1,73 @@ | |||
| import logging | |||
| from typing import Optional, List | |||
| import cohere | |||
| import openai | |||
| from langchain.schema import Document | |||
| from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \ | |||
| LLMRateLimitError, LLMAuthorizationError | |||
| from core.model_providers.models.reranking.base import BaseReranking | |||
| from core.model_providers.providers.base import BaseModelProvider | |||
class CohereReranking(BaseReranking):
    """Reranking model that delegates scoring to the Cohere client's ``rerank`` call."""

    def __init__(self, model_provider: BaseModelProvider, name: str):
        """
        :param model_provider: provider used to look up the model credentials.
        :param name: reranking model name, passed to Cohere as ``model``.
        """
        self.credentials = model_provider.get_model_credentials(
            model_name=name,
            model_type=self.type
        )

        client = cohere.Client(self.credentials.get('api_key'))

        super().__init__(model_provider, client, name)

    def rerank(self, query: str, documents: List[Document], score_threshold: Optional[float],
               top_k: Optional[int]) -> Optional[List[Document]]:
        """
        Rerank ``documents`` against ``query``.

        Documents sharing a ``doc_id`` are sent only once (first occurrence wins).

        :param query: the search query.
        :param documents: candidate documents; each needs ``doc_id``, ``doc_hash``,
            ``document_id`` and ``dataset_id`` in its metadata.
        :param score_threshold: when not None, results scoring below it are dropped.
        :param top_k: forwarded to Cohere as ``top_n``.
        :return: new ``Document`` objects carrying the relevance score in
            ``metadata['score']``; an empty list when there is nothing to rerank.
        """
        # Keep the de-duplicated documents themselves, not only their text:
        # result.index refers to a position in the list sent to Cohere, so the
        # metadata must be read from that same list, not from the raw input.
        unique_documents = []
        seen_doc_ids = set()
        for document in documents:
            doc_id = document.metadata['doc_id']
            if doc_id not in seen_doc_ids:
                seen_doc_ids.add(doc_id)
                unique_documents.append(document)

        # nothing to score: skip the remote call
        if not unique_documents:
            return []

        results = self.client.rerank(
            query=query,
            documents=[document.page_content for document in unique_documents],
            model=self.name,
            top_n=top_k
        )

        rerank_documents = []
        for result in results:
            # score threshold check
            if score_threshold is not None and not result.relevance_score >= score_threshold:
                continue

            source_metadata = unique_documents[result.index].metadata
            # format document
            rerank_documents.append(Document(
                page_content=result.document['text'],
                metadata={
                    "doc_id": source_metadata['doc_id'],
                    "doc_hash": source_metadata['doc_hash'],
                    "document_id": source_metadata['document_id'],
                    "dataset_id": source_metadata['dataset_id'],
                    'score': result.relevance_score
                }
            ))
        return rerank_documents

    def handle_exceptions(self, ex: Exception) -> Exception:
        """
        Map a provider exception to the matching LLM* error.

        NOTE(review): every branch below matches ``openai.error`` types, while
        this class talks to a ``cohere.Client``. Presumably Cohere failures fall
        through to the final ``else`` and are returned unchanged - confirm and map
        the Cohere SDK's own exception types if that is not intended.

        :param ex: the exception that was raised.
        :return: the translated exception, or ``ex`` itself when no branch matches.
        """
        if isinstance(ex, openai.error.InvalidRequestError):
            logging.warning("Invalid request to OpenAI API.")
            return LLMBadRequestError(str(ex))
        elif isinstance(ex, openai.error.APIConnectionError):
            logging.warning("Failed to connect to OpenAI API.")
            return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
        elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
            logging.warning("OpenAI service unavailable.")
            return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
        elif isinstance(ex, openai.error.RateLimitError):
            return LLMRateLimitError(str(ex))
        elif isinstance(ex, openai.error.AuthenticationError):
            return LLMAuthorizationError(str(ex))
        elif isinstance(ex, openai.error.OpenAIError):
            return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
        else:
            return ex
| @@ -0,0 +1,152 @@ | |||
| import json | |||
| from json import JSONDecodeError | |||
| from typing import Type | |||
| from langchain.schema import HumanMessage | |||
| from core.helper import encrypter | |||
| from core.model_providers.models.base import BaseProviderModel | |||
| from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode | |||
| from core.model_providers.models.reranking.cohere_reranking import CohereReranking | |||
| from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError | |||
| from models.provider import ProviderType | |||
class CohereProvider(BaseModelProvider):
    """Model provider for Cohere; reranking is the only model type it serves."""

    @property
    def provider_name(self):
        """
        Returns the name of a provider.
        """
        return 'cohere'

    def _get_text_generation_model_mode(self, model_name) -> str:
        """Return the chat mode value regardless of ``model_name``."""
        return ModelMode.CHAT.value

    def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
        """
        List the built-in models for ``model_type``.

        :param model_type: requested model type.
        :return: the two fixed rerank models for RERANKING, otherwise an empty list.
        """
        if model_type == ModelType.RERANKING:
            return [
                {
                    'id': 'rerank-english-v2.0',
                    'name': 'rerank-english-v2.0'
                },
                {
                    'id': 'rerank-multilingual-v2.0',
                    'name': 'rerank-multilingual-v2.0'
                }
            ]
        return []

    def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
        """
        Returns the model class.

        :param model_type:
        :return:
        :raises NotImplementedError: for any type other than RERANKING.
        """
        if model_type == ModelType.RERANKING:
            return CohereReranking
        raise NotImplementedError

    def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
        """
        get model parameter rules.

        NOTE(review): these are sampling rules, while ``get_model_class`` only
        serves reranking - confirm whether any caller consumes them.

        :param model_name:
        :param model_type:
        :return:
        """
        return ModelKwargsRules(
            temperature=KwargRule[float](min=0, max=1, default=0.3, precision=2),
            top_p=KwargRule[float](min=0, max=0.99, default=0.85, precision=2),
            presence_penalty=KwargRule[float](enabled=False),
            frequency_penalty=KwargRule[float](enabled=False),
            max_tokens=KwargRule[int](enabled=False),
        )

    @classmethod
    def is_provider_credentials_valid_or_raise(cls, credentials: dict):
        """
        Validates the given credentials.

        Only the presence of a non-empty ``api_key`` is checked.

        :raises CredentialsValidateFailedError: ``api_key`` is missing or empty.
        """
        if not credentials.get('api_key'):
            raise CredentialsValidateFailedError('Cohere api_key must be provided.')
        # TODO: verify the key against the Cohere API.

    @classmethod
    def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
        """
        Encrypt ``api_key`` in place for storage and return the same dict.

        :param tenant_id: tenant whose key is used for encryption.
        :param credentials: must contain ``api_key``.
        """
        credentials['api_key'] = encrypter.encrypt_token(tenant_id, credentials['api_key'])
        return credentials

    def get_provider_credentials(self, obfuscated: bool = False) -> dict:
        """
        Load and decrypt the stored provider credentials.

        :param obfuscated: when True, the decrypted key is obfuscated before returning.
        :return: the credentials dict for a custom provider, ``{}`` for any other
            provider type. An absent or unparsable config yields ``{'api_key': None}``.
        """
        if self.provider.provider_type != ProviderType.CUSTOM.value:
            return {}

        try:
            credentials = json.loads(self.provider.encrypted_config)
        except (JSONDecodeError, TypeError):
            # TypeError: json.loads rejects None, i.e. no config stored yet
            credentials = None

        if not isinstance(credentials, dict):
            credentials = {
                'api_key': None,
            }

        api_key = credentials.get('api_key')
        if api_key:
            api_key = encrypter.decrypt_token(
                self.provider.tenant_id,
                api_key
            )
            # A missing key is returned as-is rather than passed to the obfuscator.
            if obfuscated:
                api_key = encrypter.obfuscated_token(api_key)
            credentials['api_key'] = api_key

        return credentials

    def should_deduct_quota(self):
        """Always True for this provider."""
        return True

    @classmethod
    def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
        """
        check model credentials valid.

        No per-model check is performed; this always returns None.

        :param model_name:
        :param model_type:
        :param credentials:
        """
        return

    @classmethod
    def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
                                  credentials: dict) -> dict:
        """
        encrypt model credentials for save.

        This provider keeps no per-model credentials, so the result is always ``{}``.

        :param tenant_id:
        :param model_name:
        :param model_type:
        :param credentials:
        :return:
        """
        return {}

    def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
        """
        get credentials for llm use.

        Returns the provider-level credentials; ``model_name`` and ``model_type``
        are not consulted.

        :param model_name:
        :param model_type:
        :param obfuscated:
        :return:
        """
        return self.get_provider_credentials(obfuscated)
| @@ -13,5 +13,6 @@ | |||
| "huggingface_hub", | |||
| "xinference", | |||
| "openllm", | |||
| "localai" | |||
| "localai", | |||
| "cohere" | |||
| ] | |||
| @@ -0,0 +1,7 @@ | |||
| { | |||
| "support_provider_types": [ | |||
| "custom" | |||
| ], | |||
| "system_config": null, | |||
| "model_flexibility": "fixed" | |||
| } | |||
| @@ -1,11 +1,17 @@ | |||
| from typing import Optional | |||
| import json | |||
| import threading | |||
| from typing import Optional, List | |||
| from flask import Flask | |||
| from langchain import WikipediaAPIWrapper | |||
| from langchain.callbacks.manager import Callbacks | |||
| from langchain.memory.chat_memory import BaseChatMemory | |||
| from langchain.tools import BaseTool, Tool, WikipediaQueryRun | |||
| from pydantic import BaseModel, Field | |||
| from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent | |||
| from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser | |||
| from core.agent.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent | |||
| from core.agent.agent_executor import AgentExecutor, PlanningStrategy, AgentConfiguration | |||
| from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler | |||
| from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler | |||
| @@ -17,6 +23,7 @@ from core.model_providers.model_factory import ModelFactory | |||
| from core.model_providers.models.entity.model_params import ModelKwargs, ModelMode | |||
| from core.model_providers.models.llm.base import BaseLLM | |||
| from core.tool.current_datetime_tool import DatetimeTool | |||
| from core.tool.dataset_multi_retriever_tool import DatasetMultiRetrieverTool | |||
| from core.tool.dataset_retriever_tool import DatasetRetrieverTool | |||
| from core.tool.provider.serpapi_provider import SerpAPIToolProvider | |||
| from core.tool.serpapi_wrapper import OptimizedSerpAPIWrapper, OptimizedSerpAPIInput | |||
| @@ -25,6 +32,16 @@ from extensions.ext_database import db | |||
| from models.dataset import Dataset, DatasetProcessRule | |||
| from models.model import AppModelConfig | |||
# Retrieval settings applied to a dataset that has no retrieval model of its own.
default_retrieval_model = {
    "search_method": "semantic_search",
    # reranking is off, so the provider/model names are left blank
    "reranking_enable": False,
    "reranking_model": {
        "reranking_provider_name": "",
        "reranking_model_name": "",
    },
    "top_k": 2,
    "score_threshold_enable": False,
}
| class OrchestratorRuleParser: | |||
| """Parse the orchestrator rule to entities.""" | |||
| @@ -34,7 +51,7 @@ class OrchestratorRuleParser: | |||
| self.app_model_config = app_model_config | |||
| def to_agent_executor(self, conversation_message_task: ConversationMessageTask, memory: Optional[BaseChatMemory], | |||
| rest_tokens: int, chain_callback: MainChainGatherCallbackHandler, | |||
| rest_tokens: int, chain_callback: MainChainGatherCallbackHandler, tenant_id: str, | |||
| retriever_from: str = 'dev') -> Optional[AgentExecutor]: | |||
| if not self.app_model_config.agent_mode_dict: | |||
| return None | |||
| @@ -101,7 +118,8 @@ class OrchestratorRuleParser: | |||
| rest_tokens=rest_tokens, | |||
| return_resource=return_resource, | |||
| retriever_from=retriever_from, | |||
| dataset_configs=dataset_configs | |||
| dataset_configs=dataset_configs, | |||
| tenant_id=tenant_id | |||
| ) | |||
| if len(tools) == 0: | |||
| @@ -123,7 +141,7 @@ class OrchestratorRuleParser: | |||
| return chain | |||
| def to_tools(self, tool_configs: list, callbacks: Callbacks = None, **kwargs) -> list[BaseTool]: | |||
| def to_tools(self, tool_configs: list, callbacks: Callbacks = None, **kwargs) -> list[BaseTool]: | |||
| """ | |||
| Convert app agent tool configs to tools | |||
| @@ -132,6 +150,7 @@ class OrchestratorRuleParser: | |||
| :return: | |||
| """ | |||
| tools = [] | |||
| dataset_tools = [] | |||
| for tool_config in tool_configs: | |||
| tool_type = list(tool_config.keys())[0] | |||
| tool_val = list(tool_config.values())[0] | |||
| @@ -140,7 +159,7 @@ class OrchestratorRuleParser: | |||
| tool = None | |||
| if tool_type == "dataset": | |||
| tool = self.to_dataset_retriever_tool(tool_config=tool_val, **kwargs) | |||
| dataset_tools.append(tool_config) | |||
| elif tool_type == "web_reader": | |||
| tool = self.to_web_reader_tool(tool_config=tool_val, **kwargs) | |||
| elif tool_type == "google_search": | |||
| @@ -156,57 +175,81 @@ class OrchestratorRuleParser: | |||
| else: | |||
| tool.callbacks = callbacks | |||
| tools.append(tool) | |||
| # format dataset tool | |||
| if len(dataset_tools) > 0: | |||
| dataset_retriever_tools = self.to_dataset_retriever_tool(tool_configs=dataset_tools, **kwargs) | |||
| if dataset_retriever_tools: | |||
| tools.extend(dataset_retriever_tools) | |||
| return tools | |||
| def to_dataset_retriever_tool(self, tool_config: dict, conversation_message_task: ConversationMessageTask, | |||
| dataset_configs: dict, rest_tokens: int, | |||
| def to_dataset_retriever_tool(self, tool_configs: List, conversation_message_task: ConversationMessageTask, | |||
| return_resource: bool = False, retriever_from: str = 'dev', | |||
| **kwargs) \ | |||
| -> Optional[BaseTool]: | |||
| -> Optional[List[BaseTool]]: | |||
| """ | |||
| A dataset tool is a tool that can be used to retrieve information from a dataset | |||
| :param rest_tokens: | |||
| :param tool_config: | |||
| :param dataset_configs: | |||
| :param tool_configs: | |||
| :param conversation_message_task: | |||
| :param return_resource: | |||
| :param retriever_from: | |||
| :return: | |||
| """ | |||
| # get dataset from dataset id | |||
| dataset = db.session.query(Dataset).filter( | |||
| Dataset.tenant_id == self.tenant_id, | |||
| Dataset.id == tool_config.get("id") | |||
| ).first() | |||
| if not dataset: | |||
| return None | |||
| if dataset and dataset.available_document_count == 0 and dataset.available_document_count == 0: | |||
| return None | |||
| top_k = dataset_configs.get("top_k", 2) | |||
| # dynamically adjust top_k when the remaining token number is not enough to support top_k | |||
| top_k = self._dynamic_calc_retrieve_k(dataset=dataset, top_k=top_k, rest_tokens=rest_tokens) | |||
| dataset_configs = kwargs['dataset_configs'] | |||
| retrieval_model = dataset_configs.get('retrieval_model', 'single') | |||
| tools = [] | |||
| dataset_ids = [] | |||
| tenant_id = None | |||
| for tool_config in tool_configs: | |||
| # get dataset from dataset id | |||
| dataset = db.session.query(Dataset).filter( | |||
| Dataset.tenant_id == self.tenant_id, | |||
| Dataset.id == tool_config.get('dataset').get("id") | |||
| ).first() | |||
| score_threshold = None | |||
| score_threshold_config = dataset_configs.get("score_threshold") | |||
| if score_threshold_config and score_threshold_config.get("enable"): | |||
| score_threshold = score_threshold_config.get("value") | |||
| if not dataset: | |||
| return None | |||
| tool = DatasetRetrieverTool.from_dataset( | |||
| dataset=dataset, | |||
| top_k=top_k, | |||
| score_threshold=score_threshold, | |||
| callbacks=[DatasetToolCallbackHandler(conversation_message_task)], | |||
| conversation_message_task=conversation_message_task, | |||
| return_resource=return_resource, | |||
| retriever_from=retriever_from | |||
| ) | |||
| if dataset and dataset.available_document_count == 0 and dataset.available_document_count == 0: | |||
| return None | |||
| dataset_ids.append(dataset.id) | |||
| if retrieval_model == 'single': | |||
| retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model | |||
| top_k = retrieval_model['top_k'] | |||
| # dynamically adjust top_k when the remaining token number is not enough to support top_k | |||
| # top_k = self._dynamic_calc_retrieve_k(dataset=dataset, top_k=top_k, rest_tokens=rest_tokens) | |||
| score_threshold = None | |||
| score_threshold_enable = retrieval_model.get("score_threshold_enable") | |||
| if score_threshold_enable: | |||
| score_threshold = retrieval_model.get("score_threshold") | |||
| tool = DatasetRetrieverTool.from_dataset( | |||
| dataset=dataset, | |||
| top_k=top_k, | |||
| score_threshold=score_threshold, | |||
| callbacks=[DatasetToolCallbackHandler(conversation_message_task)], | |||
| conversation_message_task=conversation_message_task, | |||
| return_resource=return_resource, | |||
| retriever_from=retriever_from | |||
| ) | |||
| tools.append(tool) | |||
| if retrieval_model == 'multiple': | |||
| tool = DatasetMultiRetrieverTool.from_dataset( | |||
| dataset_ids=dataset_ids, | |||
| tenant_id=kwargs['tenant_id'], | |||
| top_k=dataset_configs.get('top_k', 2), | |||
| score_threshold=dataset_configs.get('score_threshold', 0.5) if dataset_configs.get('score_threshold_enable', False) else None, | |||
| callbacks=[DatasetToolCallbackHandler(conversation_message_task)], | |||
| conversation_message_task=conversation_message_task, | |||
| return_resource=return_resource, | |||
| retriever_from=retriever_from, | |||
| reranking_provider_name=dataset_configs.get('reranking_model').get('reranking_provider_name'), | |||
| reranking_model_name=dataset_configs.get('reranking_model').get('reranking_model_name') | |||
| ) | |||
| tools.append(tool) | |||
| return tool | |||
| return tools | |||
| def to_web_reader_tool(self, tool_config: dict, agent_model_instance: BaseLLM, **kwargs) -> Optional[BaseTool]: | |||
| """ | |||
| @@ -0,0 +1,227 @@ | |||
| import json | |||
| import threading | |||
| from typing import Type, Optional, List | |||
| from flask import current_app, Flask | |||
| from langchain.tools import BaseTool | |||
| from pydantic import Field, BaseModel | |||
| from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler | |||
| from core.conversation_message_task import ConversationMessageTask | |||
| from core.embedding.cached_embedding import CacheEmbedding | |||
| from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig | |||
| from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError | |||
| from core.model_providers.model_factory import ModelFactory | |||
| from extensions.ext_database import db | |||
| from models.dataset import Dataset, DocumentSegment, Document | |||
| from services.retrieval_service import RetrievalService | |||
# Retrieval settings applied to a dataset that has no retrieval model of its own.
default_retrieval_model = {
    "search_method": "semantic_search",
    # reranking is off, so the provider/model names are left blank
    "reranking_enable": False,
    "reranking_model": {
        "reranking_provider_name": "",
        "reranking_model_name": "",
    },
    "top_k": 2,
    "score_threshold_enable": False,
}
class DatasetMultiRetrieverToolInput(BaseModel):
    """Argument schema for the multi-dataset retriever tool: a single required query string."""
    query: str = Field(
        ...,
        description="dataset multi retriever and rerank",
    )
class DatasetMultiRetrieverTool(BaseTool):
    """
    Tool for querying multi dataset.

    Each dataset is searched on its own thread; the combined hits are then
    reranked once with the configured reranking model.
    """
    name: str = "dataset-"
    args_schema: Type[BaseModel] = DatasetMultiRetrieverToolInput
    description: str = "dataset multi retriever and rerank. "
    tenant_id: str
    dataset_ids: List[str]
    top_k: int = 2
    score_threshold: Optional[float] = None
    reranking_provider_name: str
    reranking_model_name: str
    conversation_message_task: ConversationMessageTask
    return_resource: bool
    retriever_from: str

    @classmethod
    def from_dataset(cls, dataset_ids: List[str], tenant_id: str, **kwargs):
        """
        Build a tool named ``dataset-<tenant_id>`` for the given datasets.

        :param dataset_ids: ids of the datasets to search.
        :param tenant_id: owning tenant; also used in the tool name.
        :param kwargs: remaining field values, forwarded to the constructor.
        """
        return cls(
            name=f'dataset-{tenant_id}',
            tenant_id=tenant_id,
            dataset_ids=dataset_ids,
            **kwargs
        )

    def _run(self, query: str) -> str:
        """
        Search every dataset, rerank the merged hits and return their text.

        :param query: the search query.
        :return: the matching segment contents joined by newlines, in rerank
            order; an empty string when nothing matched.
        """
        flask_app = current_app._get_current_object()
        all_documents = []
        threads = []
        for dataset_id in self.dataset_ids:
            retrieval_thread = threading.Thread(target=self._retriever, kwargs={
                'flask_app': flask_app,
                'dataset_id': dataset_id,
                'query': query,
                'all_documents': all_documents
            })
            threads.append(retrieval_thread)
            retrieval_thread.start()
        for thread in threads:
            thread.join()

        # do rerank for searched documents; with no hits there is nothing to
        # score, so the reranking model is not even resolved
        if all_documents:
            rerank = ModelFactory.get_reranking_model(
                tenant_id=self.tenant_id,
                model_provider_name=self.reranking_provider_name,
                model_name=self.reranking_model_name
            )
            # rerank() is declared Optional; treat None as "no documents"
            all_documents = rerank.rerank(query, all_documents, self.score_threshold, self.top_k) or []

        hit_callback = DatasetIndexToolCallbackHandler(self.conversation_message_task)
        hit_callback.on_tool_end(all_documents)

        index_node_ids = [document.metadata['doc_id'] for document in all_documents]
        if not index_node_ids:
            return ''

        segments = DocumentSegment.query.filter(
            DocumentSegment.completed_at.isnot(None),
            DocumentSegment.status == 'completed',
            DocumentSegment.enabled == True,  # noqa: E712 - SQLAlchemy column expression
            DocumentSegment.index_node_id.in_(index_node_ids)
        ).all()
        if not segments:
            return ''

        # restore rerank order: the query returns segments in arbitrary order
        position_by_node_id = {node_id: position for position, node_id in enumerate(index_node_ids)}
        sorted_segments = sorted(
            segments,
            key=lambda segment: position_by_node_id.get(segment.index_node_id, float('inf'))
        )

        document_context_list = []
        for segment in sorted_segments:
            if segment.answer:
                document_context_list.append(f'question:{segment.content} answer:{segment.answer}')
            else:
                document_context_list.append(segment.content)

        if self.return_resource:
            hit_callback.return_retriever_resource_info(self._build_retriever_resources(sorted_segments))

        return str("\n".join(document_context_list))

    def _build_retriever_resources(self, sorted_segments: List) -> List[dict]:
        """Describe each segment whose dataset and enabled, non-archived document still exist."""
        context_list = []
        resource_number = 1
        for segment in sorted_segments:
            dataset = Dataset.query.filter_by(
                id=segment.dataset_id
            ).first()
            document = Document.query.filter(Document.id == segment.document_id,
                                             Document.enabled == True,  # noqa: E712
                                             Document.archived == False,  # noqa: E712
                                             ).first()
            if not (dataset and document):
                continue

            source = {
                'position': resource_number,
                'dataset_id': dataset.id,
                'dataset_name': dataset.name,
                'document_id': document.id,
                'document_name': document.name,
                'data_source_type': document.data_source_type,
                'segment_id': segment.id,
                'retriever_from': self.retriever_from
            }
            # extra segment statistics are only exposed to 'dev' callers
            if self.retriever_from == 'dev':
                source['hit_count'] = segment.hit_count
                source['word_count'] = segment.word_count
                source['segment_position'] = segment.position
                source['index_node_hash'] = segment.index_node_hash
            if segment.answer:
                source['content'] = f'question:{segment.content} \nanswer:{segment.answer}'
            else:
                source['content'] = segment.content
            context_list.append(source)
            resource_number += 1
        return context_list

    async def _arun(self, tool_input: str) -> str:
        """Async execution is not supported."""
        raise NotImplementedError()

    def _retriever(self, flask_app: Flask, dataset_id: str, query: str, all_documents: List):
        """
        Thread target: search one dataset and append its hits to ``all_documents``.

        :param flask_app: application object used to open an app context on this thread.
        :param dataset_id: dataset to search; must belong to ``self.tenant_id``.
        :param query: the search query.
        :param all_documents: shared output list, extended in place.
        """
        with flask_app.app_context():
            dataset = db.session.query(Dataset).filter(
                Dataset.tenant_id == self.tenant_id,
                Dataset.id == dataset_id
            ).first()

            if not dataset:
                return []

            # get retrieval model , if the model is not setting , using default
            retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model

            if dataset.indexing_technique == "economy":
                # use keyword table query
                kw_table_index = KeywordTableIndex(
                    dataset=dataset,
                    config=KeywordTableConfig(
                        max_keywords_per_chunk=5
                    )
                )

                documents = kw_table_index.search(query, search_kwargs={'k': self.top_k})
                if documents:
                    all_documents.extend(documents)
            else:
                try:
                    embedding_model = ModelFactory.get_embedding_model(
                        tenant_id=dataset.tenant_id,
                        model_provider_name=dataset.embedding_model_provider,
                        model_name=dataset.embedding_model
                    )
                except LLMBadRequestError:
                    return []
                except ProviderTokenNotInitError:
                    return []

                embeddings = CacheEmbedding(embedding_model)

                documents = []
                threads = []
                search_method = retrieval_model['search_method']
                if self.top_k > 0:
                    # retrieval_model source with semantic
                    if search_method == 'semantic_search' or search_method == 'hybrid_search':
                        embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
                            'flask_app': current_app._get_current_object(),
                            'dataset': dataset,
                            'query': query,
                            'top_k': self.top_k,
                            'score_threshold': self.score_threshold,
                            'reranking_model': None,
                            'all_documents': documents,
                            'search_method': 'hybrid_search',
                            'embeddings': embeddings
                        })
                        threads.append(embedding_thread)
                        embedding_thread.start()

                    # retrieval_model source with full text
                    # NOTE(review): unlike the semantic branch, this one passes the
                    # dataset's own score threshold and reranking model - confirm the
                    # asymmetry is intended.
                    if search_method == 'full_text_search' or search_method == 'hybrid_search':
                        full_text_index_thread = threading.Thread(
                            target=RetrievalService.full_text_index_search,
                            kwargs={
                                'flask_app': current_app._get_current_object(),
                                'dataset': dataset,
                                'query': query,
                                'search_method': 'hybrid_search',
                                'embeddings': embeddings,
                                'score_threshold': retrieval_model['score_threshold']
                                if retrieval_model['score_threshold_enable'] else None,
                                'top_k': self.top_k,
                                'reranking_model': retrieval_model['reranking_model']
                                if retrieval_model['reranking_enable'] else None,
                                'all_documents': documents
                            })
                        threads.append(full_text_index_thread)
                        full_text_index_thread.start()

                    for thread in threads:
                        thread.join()

                all_documents.extend(documents)
| @@ -1,5 +1,6 @@ | |||
| import json | |||
| from typing import Type, Optional | |||
| import threading | |||
| from typing import Type, Optional, List | |||
| from flask import current_app | |||
| from langchain.tools import BaseTool | |||
| @@ -14,6 +15,18 @@ from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitE | |||
| from core.model_providers.model_factory import ModelFactory | |||
| from extensions.ext_database import db | |||
| from models.dataset import Dataset, DocumentSegment, Document | |||
| from services.retrieval_service import RetrievalService | |||
# Retrieval settings applied to a dataset that has no retrieval model of its own.
default_retrieval_model = {
    "search_method": "semantic_search",
    # reranking is off, so the provider/model names are left blank
    "reranking_enable": False,
    "reranking_model": {
        "reranking_provider_name": "",
        "reranking_model_name": "",
    },
    "top_k": 2,
    "score_threshold_enable": False,
}
| class DatasetRetrieverToolInput(BaseModel): | |||
| @@ -56,7 +69,9 @@ class DatasetRetrieverTool(BaseTool): | |||
| ).first() | |||
| if not dataset: | |||
| return f'[{self.name} failed to find dataset with id {self.dataset_id}.]' | |||
| return '' | |||
| # get retrieval model , if the model is not setting , using default | |||
| retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model | |||
| if dataset.indexing_technique == "economy": | |||
| # use keyword table query | |||
| @@ -83,28 +98,62 @@ class DatasetRetrieverTool(BaseTool): | |||
| return '' | |||
| embeddings = CacheEmbedding(embedding_model) | |||
| vector_index = VectorIndex( | |||
| dataset=dataset, | |||
| config=current_app.config, | |||
| embeddings=embeddings | |||
| ) | |||
| documents = [] | |||
| threads = [] | |||
| if self.top_k > 0: | |||
| documents = vector_index.search( | |||
| query, | |||
| search_type='similarity_score_threshold', | |||
| search_kwargs={ | |||
| 'k': self.top_k, | |||
| 'score_threshold': self.score_threshold, | |||
| 'filter': { | |||
| 'group_id': [dataset.id] | |||
| } | |||
| } | |||
| ) | |||
| # retrieval source with semantic | |||
| if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search': | |||
| embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={ | |||
| 'flask_app': current_app._get_current_object(), | |||
| 'dataset': dataset, | |||
| 'query': query, | |||
| 'top_k': self.top_k, | |||
| 'score_threshold': retrieval_model['score_threshold'] if retrieval_model[ | |||
| 'score_threshold_enable'] else None, | |||
| 'reranking_model': retrieval_model['reranking_model'] if retrieval_model[ | |||
| 'reranking_enable'] else None, | |||
| 'all_documents': documents, | |||
| 'search_method': retrieval_model['search_method'], | |||
| 'embeddings': embeddings | |||
| }) | |||
| threads.append(embedding_thread) | |||
| embedding_thread.start() | |||
| # retrieval_model source with full text | |||
| if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search': | |||
| full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={ | |||
| 'flask_app': current_app._get_current_object(), | |||
| 'dataset': dataset, | |||
| 'query': query, | |||
| 'search_method': retrieval_model['search_method'], | |||
| 'embeddings': embeddings, | |||
| 'score_threshold': retrieval_model['score_threshold'] if retrieval_model[ | |||
| 'score_threshold_enable'] else None, | |||
| 'top_k': self.top_k, | |||
| 'reranking_model': retrieval_model['reranking_model'] if retrieval_model[ | |||
| 'reranking_enable'] else None, | |||
| 'all_documents': documents | |||
| }) | |||
| threads.append(full_text_index_thread) | |||
| full_text_index_thread.start() | |||
| for thread in threads: | |||
| thread.join() | |||
| # hybrid search: rerank after all documents have been searched | |||
| if retrieval_model['search_method'] == 'hybrid_search': | |||
| hybrid_rerank = ModelFactory.get_reranking_model( | |||
| tenant_id=dataset.tenant_id, | |||
| model_provider_name=retrieval_model['reranking_model']['reranking_provider_name'], | |||
| model_name=retrieval_model['reranking_model']['reranking_model_name'] | |||
| ) | |||
| documents = hybrid_rerank.rerank(query, documents, | |||
| retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None, | |||
| self.top_k) | |||
| else: | |||
| documents = [] | |||
| hit_callback = DatasetIndexToolCallbackHandler(dataset.id, self.conversation_message_task) | |||
| hit_callback = DatasetIndexToolCallbackHandler(self.conversation_message_task) | |||
| hit_callback.on_tool_end(documents) | |||
| document_score_list = {} | |||
| if dataset.indexing_technique != "economy": | |||
| @@ -1,4 +1,4 @@ | |||
| from core.index.vector_index.milvus import Milvus | |||
| from core.vector_store.vector.milvus import Milvus | |||
| class MilvusVectorStore(Milvus): | |||
| @@ -4,7 +4,7 @@ from langchain.schema import Document | |||
| from qdrant_client.http.models import Filter, PointIdsList, FilterSelector | |||
| from qdrant_client.local.qdrant_local import QdrantLocal | |||
| from core.index.vector_index.qdrant import Qdrant | |||
| from core.vector_store.vector.qdrant import Qdrant | |||
| class QdrantVectorStore(Qdrant): | |||
| @@ -73,3 +73,4 @@ class QdrantVectorStore(Qdrant): | |||
| if isinstance(self.client, QdrantLocal): | |||
| self.client = cast(QdrantLocal, self.client) | |||
| self.client._load() | |||
| @@ -28,7 +28,7 @@ from langchain.docstore.document import Document | |||
| from langchain.embeddings.base import Embeddings | |||
| from langchain.vectorstores import VectorStore | |||
| from langchain.vectorstores.utils import maximal_marginal_relevance | |||
| from qdrant_client.http.models import PayloadSchemaType | |||
| from qdrant_client.http.models import PayloadSchemaType, FilterSelector, TextIndexParams, TokenizerType, TextIndexType | |||
| if TYPE_CHECKING: | |||
| from qdrant_client import grpc # noqa | |||
| @@ -189,14 +189,25 @@ class Qdrant(VectorStore): | |||
| texts, metadatas, ids, batch_size | |||
| ): | |||
| self.client.upsert( | |||
| collection_name=self.collection_name, points=points, **kwargs | |||
| collection_name=self.collection_name, points=points | |||
| ) | |||
| added_ids.extend(batch_ids) | |||
| # if is new collection, create payload index on group_id | |||
| if self.is_new_collection: | |||
| # create payload index | |||
| self.client.create_payload_index(self.collection_name, self.group_payload_key, | |||
| field_schema=PayloadSchemaType.KEYWORD, | |||
| field_type=PayloadSchemaType.KEYWORD) | |||
| # create full text index | |||
| text_index_params = TextIndexParams( | |||
| type=TextIndexType.TEXT, | |||
| tokenizer=TokenizerType.MULTILINGUAL, | |||
| min_token_len=2, | |||
| max_token_len=20, | |||
| lowercase=True | |||
| ) | |||
| self.client.create_payload_index(self.collection_name, self.content_payload_key, | |||
| field_schema=text_index_params) | |||
| return added_ids | |||
| @sync_call_fallback | |||
| @@ -600,7 +611,7 @@ class Qdrant(VectorStore): | |||
| limit=k, | |||
| offset=offset, | |||
| with_payload=True, | |||
| with_vectors=True, # Langchain does not expect vectors to be returned | |||
| with_vectors=True, | |||
| score_threshold=score_threshold, | |||
| consistency=consistency, | |||
| **kwargs, | |||
| @@ -615,6 +626,39 @@ class Qdrant(VectorStore): | |||
| for result in results | |||
| ] | |||
def similarity_search_by_bm25(
    self,
    filter: Optional[MetadataFilter] = None,
    k: int = 4
) -> List[Document]:
    """Fetch up to ``k`` documents whose payload matches ``filter``.

    The lookup is a plain ``scroll`` over the collection: the caller is
    expected to put the full-text condition inside ``filter``. No query
    embedding is involved and no score is computed here.

    Args:
        filter: Scroll filter forwarded to the client. Defaults to None.
        k: Maximum number of documents to return. Defaults to 4.

    Returns:
        The matching documents, in the order the client returned them.
    """
    # scroll() returns a tuple; element 0 holds the fetched points.
    points = self.client.scroll(
        collection_name=self.collection_name,
        scroll_filter=filter,
        limit=k,
        with_payload=True,
        with_vectors=True
    )[0]
    # Falsy entries are skipped, the rest are converted to documents.
    return [
        self._document_from_scored_point(
            point, self.content_payload_key, self.metadata_payload_key
        )
        for point in points
        if point
    ]
| @sync_call_fallback | |||
| async def asimilarity_search_with_score_by_vector( | |||
| self, | |||
| @@ -0,0 +1,505 @@ | |||
| """Wrapper around weaviate vector database.""" | |||
| from __future__ import annotations | |||
| import datetime | |||
| from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type | |||
| from uuid import uuid4 | |||
| import numpy as np | |||
| from langchain.docstore.document import Document | |||
| from langchain.embeddings.base import Embeddings | |||
| from langchain.utils import get_from_dict_or_env | |||
| from langchain.vectorstores.base import VectorStore | |||
| from langchain.vectorstores.utils import maximal_marginal_relevance | |||
| def _default_schema(index_name: str) -> Dict: | |||
| return { | |||
| "class": index_name, | |||
| "properties": [ | |||
| { | |||
| "name": "text", | |||
| "dataType": ["text"], | |||
| } | |||
| ], | |||
| } | |||
| def _create_weaviate_client(**kwargs: Any) -> Any: | |||
| client = kwargs.get("client") | |||
| if client is not None: | |||
| return client | |||
| weaviate_url = get_from_dict_or_env(kwargs, "weaviate_url", "WEAVIATE_URL") | |||
| try: | |||
| # the weaviate api key param should not be mandatory | |||
| weaviate_api_key = get_from_dict_or_env( | |||
| kwargs, "weaviate_api_key", "WEAVIATE_API_KEY", None | |||
| ) | |||
| except ValueError: | |||
| weaviate_api_key = None | |||
| try: | |||
| import weaviate | |||
| except ImportError: | |||
| raise ValueError( | |||
| "Could not import weaviate python package. " | |||
| "Please install it with `pip install weaviate-client`" | |||
| ) | |||
| auth = ( | |||
| weaviate.auth.AuthApiKey(api_key=weaviate_api_key) | |||
| if weaviate_api_key is not None | |||
| else None | |||
| ) | |||
| client = weaviate.Client(weaviate_url, auth_client_secret=auth) | |||
| return client | |||
| def _default_score_normalizer(val: float) -> float: | |||
| return 1 - 1 / (1 + np.exp(val)) | |||
| def _json_serializable(value: Any) -> Any: | |||
| if isinstance(value, datetime.datetime): | |||
| return value.isoformat() | |||
| return value | |||
class Weaviate(VectorStore):
    """Wrapper around Weaviate vector database.

    To use, you should have the ``weaviate-client`` python package installed.

    Example:
        .. code-block:: python

            import weaviate
            from langchain.vectorstores import Weaviate

            client = weaviate.Client(url=os.environ["WEAVIATE_URL"], ...)
            weaviate = Weaviate(client, index_name, text_key)
    """

    def __init__(
        self,
        client: Any,
        index_name: str,
        text_key: str,
        embedding: Optional[Embeddings] = None,
        attributes: Optional[List[str]] = None,
        relevance_score_fn: Optional[
            Callable[[float], float]
        ] = _default_score_normalizer,
        by_text: bool = True,
    ):
        """Initialize with Weaviate client.

        Args:
            client: A ``weaviate.Client`` instance.
            index_name: Name of the Weaviate class to read from / write to.
            text_key: Property that stores the document text.
            embedding: Optional embedding model used for vector operations.
            attributes: Extra properties to fetch alongside ``text_key``.
            relevance_score_fn: Optional score normalizer.
            by_text: If True, ``similarity_search`` uses near-text queries;
                otherwise it embeds the query and searches by vector.

        Raises:
            ValueError: if ``weaviate`` is not installed or ``client`` is not
                a ``weaviate.Client``.
        """
        try:
            import weaviate
        except ImportError:
            raise ValueError(
                "Could not import weaviate python package. "
                "Please install it with `pip install weaviate-client`."
            )
        if not isinstance(client, weaviate.Client):
            raise ValueError(
                f"client should be an instance of weaviate.Client, got {type(client)}"
            )
        self._client = client
        self._index_name = index_name
        self._embedding = embedding
        self._text_key = text_key
        self._query_attrs = [self._text_key]
        self.relevance_score_fn = relevance_score_fn
        self._by_text = by_text
        if attributes is not None:
            self._query_attrs.extend(attributes)

    @property
    def embeddings(self) -> Optional[Embeddings]:
        """The embedding model passed at construction time, if any."""
        return self._embedding

    def _select_relevance_score_fn(self) -> Callable[[float], float]:
        """Return the configured normalizer, or the module default."""
        return (
            self.relevance_score_fn
            if self.relevance_score_fn
            else _default_score_normalizer
        )

    def _documents_from_result(self, result: Dict[str, Any]) -> List[Document]:
        """Turn a ``Get`` response into documents, raising on query errors."""
        if "errors" in result:
            raise ValueError(f"Error during query: {result['errors']}")
        docs = []
        for res in result["data"]["Get"][self._index_name]:
            # The text property becomes page_content; the rest is metadata.
            text = res.pop(self._text_key)
            docs.append(Document(page_content=text, metadata=res))
        return docs

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Upload texts with metadata (properties) to Weaviate.

        Args:
            texts: Texts to store.
            metadatas: Optional per-text property dicts.
            **kwargs: ``uuids`` or ``ids`` may supply explicit object ids
                (``uuids`` wins when both are present).

        Returns:
            The ids of the stored objects, in input order.
        """
        from weaviate.util import get_valid_uuid

        ids = []
        embeddings: Optional[List[List[float]]] = None
        if self._embedding:
            if not isinstance(texts, list):
                texts = list(texts)
            embeddings = self._embedding.embed_documents(texts)

        with self._client.batch as batch:
            for i, text in enumerate(texts):
                data_properties = {self._text_key: text}
                if metadatas is not None:
                    for key, val in metadatas[i].items():
                        data_properties[key] = _json_serializable(val)

                # Allow for ids (consistent w/ other methods)
                # or uuids (backwards compatible w/ existing arg).
                # If the UUID of one of the objects already exists
                # then the existing object will be replaced by the new object.
                if "uuids" in kwargs:
                    _id = kwargs["uuids"][i]
                elif "ids" in kwargs:
                    _id = kwargs["ids"][i]
                else:
                    _id = get_valid_uuid(uuid4())

                batch.add_data_object(
                    data_object=data_properties,
                    class_name=self._index_name,
                    uuid=_id,
                    vector=embeddings[i] if embeddings else None,
                )
                ids.append(_id)
        return ids

    def similarity_search(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> List[Document]:
        """Return docs most similar to query.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.

        Returns:
            List of Documents most similar to the query.

        Raises:
            ValueError: if vector search is selected but no embedding is set.
        """
        if self._by_text:
            return self.similarity_search_by_text(query, k, **kwargs)
        if self._embedding is None:
            raise ValueError(
                "_embedding cannot be None for similarity_search when "
                "_by_text=False"
            )
        embedding = self._embedding.embed_query(query)
        return self.similarity_search_by_vector(embedding, k, **kwargs)

    def similarity_search_by_text(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> List[Document]:
        """Return docs most similar to query using a near-text search.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            **kwargs: ``search_distance`` (sent as ``certainty``),
                ``where_filter`` and ``additional`` are honoured.

        Returns:
            List of Documents most similar to the query.
        """
        content: Dict[str, Any] = {"concepts": [query]}
        if kwargs.get("search_distance"):
            content["certainty"] = kwargs.get("search_distance")
        query_obj = self._client.query.get(self._index_name, self._query_attrs)
        if kwargs.get("where_filter"):
            query_obj = query_obj.with_where(kwargs.get("where_filter"))
        if kwargs.get("additional"):
            query_obj = query_obj.with_additional(kwargs.get("additional"))
        result = query_obj.with_near_text(content).with_limit(k).do()
        return self._documents_from_result(result)

    def similarity_search_by_bm25(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> List[Document]:
        """Return docs using BM25F keyword search.

        Args:
            query: Text to search for.
            k: Number of Documents to return. Defaults to 4.
            **kwargs: ``where_filter`` and ``additional`` are honoured.

        Returns:
            List of Documents matching the query.
        """
        query_obj = self._client.query.get(self._index_name, self._query_attrs)
        if kwargs.get("where_filter"):
            query_obj = query_obj.with_where(kwargs.get("where_filter"))
        if kwargs.get("additional"):
            query_obj = query_obj.with_additional(kwargs.get("additional"))
        # with_bm25 takes the raw query string, not a near-text
        # {"concepts": [...]} dict.
        result = query_obj.with_bm25(query=query).with_limit(k).do()
        return self._documents_from_result(result)

    def similarity_search_by_vector(
        self, embedding: List[float], k: int = 4, **kwargs: Any
    ) -> List[Document]:
        """Look up similar documents by embedding vector in Weaviate."""
        vector = {"vector": embedding}
        query_obj = self._client.query.get(self._index_name, self._query_attrs)
        if kwargs.get("where_filter"):
            query_obj = query_obj.with_where(kwargs.get("where_filter"))
        if kwargs.get("additional"):
            query_obj = query_obj.with_additional(kwargs.get("additional"))
        result = query_obj.with_near_vector(vector).with_limit(k).do()
        return self._documents_from_result(result)

    def max_marginal_relevance_search(
        self,
        query: str,
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to query AND diversity
        among selected documents.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree
                of diversity among the results with 0 corresponding
                to maximum diversity and 1 to minimum diversity.
                Defaults to 0.5.

        Returns:
            List of Documents selected by maximal marginal relevance.

        Raises:
            ValueError: if no embedding model is configured.
        """
        if self._embedding is not None:
            embedding = self._embedding.embed_query(query)
        else:
            raise ValueError(
                "max_marginal_relevance_search requires a suitable Embeddings object"
            )

        return self.max_marginal_relevance_search_by_vector(
            embedding, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, **kwargs
        )

    def max_marginal_relevance_search_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to query AND diversity
        among selected documents.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree
                of diversity among the results with 0 corresponding
                to maximum diversity and 1 to minimum diversity.
                Defaults to 0.5.

        Returns:
            List of Documents selected by maximal marginal relevance.

        Raises:
            ValueError: if Weaviate reports an error for the query.
        """
        vector = {"vector": embedding}
        query_obj = self._client.query.get(self._index_name, self._query_attrs)
        if kwargs.get("where_filter"):
            query_obj = query_obj.with_where(kwargs.get("where_filter"))
        results = (
            query_obj.with_additional("vector")
            .with_near_vector(vector)
            .with_limit(fetch_k)
            .do()
        )
        # Same error handling as the other query paths, rather than failing
        # later while indexing into the response.
        if "errors" in results:
            raise ValueError(f"Error during query: {results['errors']}")

        payload = results["data"]["Get"][self._index_name]
        embeddings = [result["_additional"]["vector"] for result in payload]
        mmr_selected = maximal_marginal_relevance(
            np.array(embedding), embeddings, k=k, lambda_mult=lambda_mult
        )

        docs = []
        for idx in mmr_selected:
            text = payload[idx].pop(self._text_key)
            # The stored vector is only needed for MMR, not as metadata.
            payload[idx].pop("_additional")
            meta = payload[idx]
            docs.append(Document(page_content=text, metadata=meta))
        return docs

    def similarity_search_with_score(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> List[Tuple[Document, float]]:
        """Return documents most similar to the query, each with a score.

        The score is the dot product between the stored vector of the hit and
        the embedded query. The returned metadata still contains the
        ``_additional`` entry holding that stored vector.

        Raises:
            ValueError: if no embedding model is configured or the query fails.
        """
        if self._embedding is None:
            raise ValueError(
                "_embedding cannot be None for similarity_search_with_score"
            )
        content: Dict[str, Any] = {"concepts": [query]}
        if kwargs.get("search_distance"):
            content["certainty"] = kwargs.get("search_distance")
        query_obj = self._client.query.get(self._index_name, self._query_attrs)

        embedded_query = self._embedding.embed_query(query)
        if not self._by_text:
            vector = {"vector": embedded_query}
            result = (
                query_obj.with_near_vector(vector)
                .with_limit(k)
                .with_additional("vector")
                .do()
            )
        else:
            result = (
                query_obj.with_near_text(content)
                .with_limit(k)
                .with_additional("vector")
                .do()
            )

        if "errors" in result:
            raise ValueError(f"Error during query: {result['errors']}")

        docs_and_scores = []
        for res in result["data"]["Get"][self._index_name]:
            text = res.pop(self._text_key)
            score = np.dot(res["_additional"]["vector"], embedded_query)
            docs_and_scores.append((Document(page_content=text, metadata=res), score))
        return docs_and_scores

    @classmethod
    def from_texts(
        cls: Type[Weaviate],
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> Weaviate:
        """Construct Weaviate wrapper from raw documents.

        This is a user-friendly interface that:
            1. Embeds documents.
            2. Creates a new index for the embeddings in the Weaviate instance.
            3. Adds the documents to the newly created Weaviate index.

        This is intended to be a quick way to get started.

        Example:
            .. code-block:: python

                from langchain.vectorstores.weaviate import Weaviate
                from langchain.embeddings import OpenAIEmbeddings
                embeddings = OpenAIEmbeddings()
                weaviate = Weaviate.from_texts(
                    texts,
                    embeddings,
                    weaviate_url="http://localhost:8080"
                )
        """
        client = _create_weaviate_client(**kwargs)

        from weaviate.util import get_valid_uuid

        index_name = kwargs.get("index_name", f"LangChain_{uuid4().hex}")
        embeddings = embedding.embed_documents(texts) if embedding else None
        text_key = "text"
        schema = _default_schema(index_name)
        attributes = list(metadatas[0].keys()) if metadatas else None

        # check whether the index already exists
        if not client.schema.contains(schema):
            client.schema.create_class(schema)

        with client.batch as batch:
            for i, text in enumerate(texts):
                data_properties = {
                    text_key: text,
                }
                if metadatas is not None:
                    for key in metadatas[i].keys():
                        data_properties[key] = metadatas[i][key]

                # If the UUID of one of the objects already exists
                # then the existing object will be replaced by the new object.
                if "uuids" in kwargs:
                    _id = kwargs["uuids"][i]
                else:
                    _id = get_valid_uuid(uuid4())

                # if an embedding strategy is not provided, we let
                # weaviate create the embedding. Note that this will only
                # work if weaviate has been installed with a vectorizer module
                # like text2vec-contextionary for example
                params = {
                    "uuid": _id,
                    "data_object": data_properties,
                    "class_name": index_name,
                }
                if embeddings is not None:
                    params["vector"] = embeddings[i]

                batch.add_data_object(**params)

            batch.flush()

        relevance_score_fn = kwargs.get("relevance_score_fn")
        by_text: bool = kwargs.get("by_text", False)

        return cls(
            client,
            index_name,
            text_key,
            embedding=embedding,
            attributes=attributes,
            relevance_score_fn=relevance_score_fn,
            by_text=by_text,
        )

    def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> None:
        """Delete by vector IDs.

        Args:
            ids: List of ids to delete.

        Raises:
            ValueError: if ``ids`` is None.
        """
        if ids is None:
            raise ValueError("No ids provided to delete.")

        # TODO: Check if this can be done in bulk
        for object_id in ids:
            self._client.data_object.delete(uuid=object_id)
| @@ -12,6 +12,21 @@ dataset_fields = { | |||
| 'created_at': TimestampField, | |||
| } | |||
# Marshalling schema for the reranking model chosen for a dataset:
# provider name plus model name, both serialized as strings.
reranking_model_fields = {
    'reranking_provider_name': fields.String,
    'reranking_model_name': fields.String
}

# Marshalling schema for a dataset's retrieval settings: the search method,
# whether reranking is enabled and with which model, the top_k limit and the
# optional score threshold.
dataset_retrieval_model_fields = {
    'search_method': fields.String,
    'reranking_enable': fields.Boolean,
    'reranking_model': fields.Nested(reranking_model_fields),
    'top_k': fields.Integer,
    'score_threshold_enable': fields.Boolean,
    'score_threshold': fields.Float
}
| dataset_detail_fields = { | |||
| 'id': fields.String, | |||
| 'name': fields.String, | |||
| @@ -29,7 +44,8 @@ dataset_detail_fields = { | |||
| 'updated_at': TimestampField, | |||
| 'embedding_model': fields.String, | |||
| 'embedding_model_provider': fields.String, | |||
| 'embedding_available': fields.Boolean | |||
| 'embedding_available': fields.Boolean, | |||
| 'retrieval_model_dict': fields.Nested(dataset_retrieval_model_fields) | |||
| } | |||
| dataset_query_detail_fields = { | |||
| @@ -41,3 +57,5 @@ dataset_query_detail_fields = { | |||
| "created_by": fields.String, | |||
| "created_at": TimestampField | |||
| } | |||
| @@ -0,0 +1,43 @@ | |||
| """add-dataset-retrival-model | |||
| Revision ID: fca025d3b60f | |||
| Revises: 8fe468ba0ca5 | |||
| Create Date: 2023-11-03 13:08:23.246396 | |||
| """ | |||
| from alembic import op | |||
| import sqlalchemy as sa | |||
| from sqlalchemy.dialects import postgresql | |||
| # revision identifiers, used by Alembic. | |||
| revision = 'fca025d3b60f' | |||
| down_revision = '8fe468ba0ca5' | |||
| branch_labels = None | |||
| depends_on = None | |||
def upgrade():
    """Add the ``retrieval_model`` JSONB column and its GIN index to ``datasets``.

    The auto-generated script also drops the ``sessions`` table.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    # NOTE(review): dropping `sessions` is unrelated to the retrieval model;
    # presumably autogenerate drift - confirm it is intended.
    op.drop_table('sessions')
    with op.batch_alter_table('datasets', schema=None) as batch_op:
        # Nullable, so existing rows need no backfill.
        batch_op.add_column(sa.Column('retrieval_model', postgresql.JSONB(astext_type=sa.Text()), nullable=True))
        batch_op.create_index('retrieval_model_idx', ['retrieval_model'], unique=False, postgresql_using='gin')
    # ### end Alembic commands ###
def downgrade():
    """Revert ``upgrade``: drop the index and column, recreate ``sessions``.

    The recreated ``sessions`` table is empty; rows removed by the upgrade
    are not restored.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('datasets', schema=None) as batch_op:
        # The index is dropped before the column it covers.
        batch_op.drop_index('retrieval_model_idx', postgresql_using='gin')
        batch_op.drop_column('retrieval_model')
    op.create_table('sessions',
        sa.Column('id', sa.INTEGER(), autoincrement=True, nullable=False),
        sa.Column('session_id', sa.VARCHAR(length=255), autoincrement=False, nullable=True),
        sa.Column('data', postgresql.BYTEA(), autoincrement=False, nullable=True),
        sa.Column('expiry', postgresql.TIMESTAMP(), autoincrement=False, nullable=True),
        sa.PrimaryKeyConstraint('id', name='sessions_pkey'),
        sa.UniqueConstraint('session_id', name='sessions_session_id_key')
    )
    # ### end Alembic commands ###
| @@ -3,7 +3,7 @@ import pickle | |||
| from json import JSONDecodeError | |||
| from sqlalchemy import func | |||
| from sqlalchemy.dialects.postgresql import UUID | |||
| from sqlalchemy.dialects.postgresql import UUID, JSONB | |||
| from extensions.ext_database import db | |||
| from models.account import Account | |||
| @@ -15,6 +15,7 @@ class Dataset(db.Model): | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint('id', name='dataset_pkey'), | |||
| db.Index('dataset_tenant_idx', 'tenant_id'), | |||
| db.Index('retrieval_model_idx', "retrieval_model", postgresql_using='gin') | |||
| ) | |||
| INDEXING_TECHNIQUE_LIST = ['high_quality', 'economy'] | |||
| @@ -39,7 +40,7 @@ class Dataset(db.Model): | |||
| embedding_model = db.Column(db.String(255), nullable=True) | |||
| embedding_model_provider = db.Column(db.String(255), nullable=True) | |||
| collection_binding_id = db.Column(UUID, nullable=True) | |||
| retrieval_model = db.Column(JSONB, nullable=True) | |||
| @property | |||
| def dataset_keyword_table(self): | |||
| @@ -93,6 +94,20 @@ class Dataset(db.Model): | |||
| return Document.query.with_entities(func.coalesce(func.sum(Document.word_count))) \ | |||
| .filter(Document.dataset_id == self.id).scalar() | |||
@property
def retrieval_model_dict(self):
    """Retrieval settings for this dataset, with a built-in fallback.

    Returns the stored ``retrieval_model`` when it is truthy; otherwise a
    freshly built default (semantic search, no reranking, ``top_k`` of 2,
    score threshold disabled).
    """
    stored = self.retrieval_model
    if stored:
        return stored
    return {
        'search_method': 'semantic_search',
        'reranking_enable': False,
        'reranking_model': {
            'reranking_provider_name': '',
            'reranking_model_name': ''
        },
        'top_k': 2,
        'score_threshold_enable': False
    }
| class DatasetProcessRule(db.Model): | |||
| __tablename__ = 'dataset_process_rules' | |||
| @@ -120,7 +135,7 @@ class DatasetProcessRule(db.Model): | |||
| ], | |||
| 'segmentation': { | |||
| 'delimiter': '\n', | |||
| 'max_tokens': 1000 | |||
| 'max_tokens': 512 | |||
| } | |||
| } | |||
| @@ -462,4 +477,3 @@ class DatasetCollectionBinding(db.Model): | |||
| model_name = db.Column(db.String(40), nullable=False) | |||
| collection_name = db.Column(db.String(64), nullable=False) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| @@ -160,7 +160,13 @@ class AppModelConfig(db.Model): | |||
| @property | |||
| def dataset_configs_dict(self) -> dict: | |||
| return json.loads(self.dataset_configs) if self.dataset_configs else {"top_k": 2, "score_threshold": {"enable": False}} | |||
| if self.dataset_configs: | |||
| dataset_configs = json.loads(self.dataset_configs) | |||
| if 'retrieval_model' not in dataset_configs: | |||
| return {'retrieval_model': 'single'} | |||
| else: | |||
| return dataset_configs | |||
| return {'retrieval_model': 'single'} | |||
| @property | |||
| def file_upload_dict(self) -> dict: | |||
| @@ -23,7 +23,6 @@ boto3==1.28.17 | |||
| tenacity==8.2.2 | |||
| cachetools~=5.3.0 | |||
| weaviate-client~=3.21.0 | |||
| qdrant_client~=1.1.6 | |||
| mailchimp-transactional~=1.0.50 | |||
| scikit-learn==1.2.2 | |||
| sentry-sdk[flask]~=1.21.1 | |||
| @@ -53,4 +52,6 @@ xinference-client~=0.5.4 | |||
| safetensors==0.3.2 | |||
| zhipuai==1.0.7 | |||
| werkzeug==2.3.7 | |||
| pymilvus==2.3.0 | |||
| pymilvus==2.3.0 | |||
| qdrant-client==1.6.4 | |||
| cohere~=4.32 | |||
| @@ -470,7 +470,16 @@ class AppModelConfigService: | |||
| # dataset_configs | |||
| if 'dataset_configs' not in config or not config["dataset_configs"]: | |||
| config["dataset_configs"] = {"top_k": 2, "score_threshold": {"enable": False}} | |||
| config["dataset_configs"] = {'retrieval_model': 'single'} | |||
| if not isinstance(config["dataset_configs"], dict): | |||
| raise ValueError("dataset_configs must be of object type") | |||
| if config["dataset_configs"]['retrieval_model'] == 'multiple': | |||
| if not config["dataset_configs"]['reranking_model']: | |||
| raise ValueError("reranking_model has not been set") | |||
| if not isinstance(config["dataset_configs"]['reranking_model'], dict): | |||
| raise ValueError("reranking_model must be of object type") | |||
| if not isinstance(config["dataset_configs"], dict): | |||
| raise ValueError("dataset_configs must be of object type") | |||
| @@ -173,6 +173,9 @@ class DatasetService: | |||
| filtered_data['updated_by'] = user.id | |||
| filtered_data['updated_at'] = datetime.datetime.now() | |||
| # update Retrieval model | |||
| filtered_data['retrieval_model'] = data['retrieval_model'] | |||
| dataset.query.filter_by(id=dataset_id).update(filtered_data) | |||
| db.session.commit() | |||
| @@ -473,7 +476,19 @@ class DocumentService: | |||
| embedding_model.name | |||
| ) | |||
| dataset.collection_binding_id = dataset_collection_binding.id | |||
| if not dataset.retrieval_model: | |||
| default_retrieval_model = { | |||
| 'search_method': 'semantic_search', | |||
| 'reranking_enable': False, | |||
| 'reranking_model': { | |||
| 'reranking_provider_name': '', | |||
| 'reranking_model_name': '' | |||
| }, | |||
| 'top_k': 2, | |||
| 'score_threshold_enable': False | |||
| } | |||
| dataset.retrieval_model = document_data.get('retrieval_model') if document_data.get('retrieval_model') else default_retrieval_model | |||
| documents = [] | |||
| batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999)) | |||
| @@ -733,6 +748,7 @@ class DocumentService: | |||
| raise ValueError(f"All your documents have overed limit {tenant_document_count}.") | |||
| embedding_model = None | |||
| dataset_collection_binding_id = None | |||
| retrieval_model = None | |||
| if document_data['indexing_technique'] == 'high_quality': | |||
| embedding_model = ModelFactory.get_embedding_model( | |||
| tenant_id=tenant_id | |||
| @@ -742,6 +758,20 @@ class DocumentService: | |||
| embedding_model.name | |||
| ) | |||
| dataset_collection_binding_id = dataset_collection_binding.id | |||
| if 'retrieval_model' in document_data and document_data['retrieval_model']: | |||
| retrieval_model = document_data['retrieval_model'] | |||
| else: | |||
| default_retrieval_model = { | |||
| 'search_method': 'semantic_search', | |||
| 'reranking_enable': False, | |||
| 'reranking_model': { | |||
| 'reranking_provider_name': '', | |||
| 'reranking_model_name': '' | |||
| }, | |||
| 'top_k': 2, | |||
| 'score_threshold_enable': False | |||
| } | |||
| retrieval_model = default_retrieval_model | |||
| # save dataset | |||
| dataset = Dataset( | |||
| tenant_id=tenant_id, | |||
| @@ -751,7 +781,8 @@ class DocumentService: | |||
| created_by=account.id, | |||
| embedding_model=embedding_model.name if embedding_model else None, | |||
| embedding_model_provider=embedding_model.model_provider.provider_name if embedding_model else None, | |||
| collection_binding_id=dataset_collection_binding_id | |||
| collection_binding_id=dataset_collection_binding_id, | |||
| retrieval_model=retrieval_model | |||
| ) | |||
| db.session.add(dataset) | |||
| @@ -768,7 +799,7 @@ class DocumentService: | |||
| return dataset, documents, batch | |||
| @classmethod | |||
| def document_create_args_validate(cls, args: dict): | |||
| def document_create_args_validate(cls, args: dict): | |||
| if 'original_document_id' not in args or not args['original_document_id']: | |||
| DocumentService.data_source_args_validate(args) | |||
| DocumentService.process_rule_args_validate(args) | |||
| @@ -1,4 +1,6 @@ | |||
| import json | |||
| import logging | |||
| import threading | |||
| import time | |||
| from typing import List | |||
| @@ -9,16 +11,26 @@ from langchain.schema import Document | |||
| from sklearn.manifold import TSNE | |||
| from core.embedding.cached_embedding import CacheEmbedding | |||
| from core.index.vector_index.vector_index import VectorIndex | |||
| from core.model_providers.model_factory import ModelFactory | |||
| from extensions.ext_database import db | |||
| from models.account import Account | |||
| from models.dataset import Dataset, DocumentSegment, DatasetQuery | |||
| from services.retrieval_service import RetrievalService | |||
# Fallback retrieval settings: HitTestingService.retrieve uses them when
# neither the caller nor the dataset provides a retrieval model.
default_retrieval_model = dict(
    search_method='semantic_search',
    reranking_enable=False,
    reranking_model=dict(
        reranking_provider_name='',
        reranking_model_name='',
    ),
    top_k=2,
    score_threshold_enable=False,
)
| class HitTestingService: | |||
| @classmethod | |||
| def retrieve(cls, dataset: Dataset, query: str, account: Account, limit: int = 10) -> dict: | |||
| def retrieve(cls, dataset: Dataset, query: str, account: Account, retrieval_model: dict, limit: int = 10) -> dict: | |||
| if dataset.available_document_count == 0 or dataset.available_segment_count == 0: | |||
| return { | |||
| "query": { | |||
| @@ -28,31 +40,68 @@ class HitTestingService: | |||
| "records": [] | |||
| } | |||
| start = time.perf_counter() | |||
| # get retrieval model , if the model is not setting , using default | |||
| if not retrieval_model: | |||
| retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model | |||
| # get embedding model | |||
| embedding_model = ModelFactory.get_embedding_model( | |||
| tenant_id=dataset.tenant_id, | |||
| model_provider_name=dataset.embedding_model_provider, | |||
| model_name=dataset.embedding_model | |||
| ) | |||
| embeddings = CacheEmbedding(embedding_model) | |||
| vector_index = VectorIndex( | |||
| dataset=dataset, | |||
| config=current_app.config, | |||
| embeddings=embeddings | |||
| ) | |||
| all_documents = [] | |||
| threads = [] | |||
| # retrieval_model source with semantic | |||
| if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search': | |||
| embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={ | |||
| 'flask_app': current_app._get_current_object(), | |||
| 'dataset': dataset, | |||
| 'query': query, | |||
| 'top_k': retrieval_model['top_k'], | |||
| 'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None, | |||
| 'reranking_model': retrieval_model['reranking_model'] if retrieval_model['reranking_enable'] else None, | |||
| 'all_documents': all_documents, | |||
| 'search_method': retrieval_model['search_method'], | |||
| 'embeddings': embeddings | |||
| }) | |||
| threads.append(embedding_thread) | |||
| embedding_thread.start() | |||
| # retrieval source with full text | |||
| if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search': | |||
| full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={ | |||
| 'flask_app': current_app._get_current_object(), | |||
| 'dataset': dataset, | |||
| 'query': query, | |||
| 'search_method': retrieval_model['search_method'], | |||
| 'embeddings': embeddings, | |||
| 'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None, | |||
| 'top_k': retrieval_model['top_k'], | |||
| 'reranking_model': retrieval_model['reranking_model'] if retrieval_model['reranking_enable'] else None, | |||
| 'all_documents': all_documents | |||
| }) | |||
| threads.append(full_text_index_thread) | |||
| full_text_index_thread.start() | |||
| for thread in threads: | |||
| thread.join() | |||
| if retrieval_model['search_method'] == 'hybrid_search': | |||
| hybrid_rerank = ModelFactory.get_reranking_model( | |||
| tenant_id=dataset.tenant_id, | |||
| model_provider_name=retrieval_model['reranking_model']['reranking_provider_name'], | |||
| model_name=retrieval_model['reranking_model']['reranking_model_name'] | |||
| ) | |||
| all_documents = hybrid_rerank.rerank(query, all_documents, | |||
| retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None, | |||
| retrieval_model['top_k']) | |||
| start = time.perf_counter() | |||
| documents = vector_index.search( | |||
| query, | |||
| search_type='similarity_score_threshold', | |||
| search_kwargs={ | |||
| 'k': 10, | |||
| 'filter': { | |||
| 'group_id': [dataset.id] | |||
| } | |||
| } | |||
| ) | |||
| end = time.perf_counter() | |||
| logging.debug(f"Hit testing retrieve in {end - start:0.4f} seconds") | |||
| @@ -67,7 +116,7 @@ class HitTestingService: | |||
| db.session.add(dataset_query) | |||
| db.session.commit() | |||
| return cls.compact_retrieve_response(dataset, embeddings, query, documents) | |||
| return cls.compact_retrieve_response(dataset, embeddings, query, all_documents) | |||
| @classmethod | |||
| def compact_retrieve_response(cls, dataset: Dataset, embeddings: Embeddings, query: str, documents: List[Document]): | |||
| @@ -99,7 +148,7 @@ class HitTestingService: | |||
| record = { | |||
| "segment": segment, | |||
| "score": document.metadata['score'], | |||
| "score": document.metadata.get('score', None), | |||
| "tsne_position": tsne_position_data[i] | |||
| } | |||
| @@ -136,3 +185,11 @@ class HitTestingService: | |||
| tsne_position_data.append({'x': float(data_tsne[i][0]), 'y': float(data_tsne[i][1])}) | |||
| return tsne_position_data | |||
| @classmethod | |||
| def hit_testing_args_check(cls, args): | |||
| query = args['query'] | |||
| if not query or len(query) > 250: | |||
| raise ValueError('Query is required and cannot exceed 250 characters') | |||
| @@ -0,0 +1,88 @@ | |||
| from typing import Optional | |||
| from flask import current_app, Flask | |||
| from langchain.embeddings.base import Embeddings | |||
| from core.index.vector_index.vector_index import VectorIndex | |||
| from core.model_providers.model_factory import ModelFactory | |||
| from models.dataset import Dataset | |||
# Default retrieval settings: semantic search with a top_k of 2, with
# reranking and the score threshold both disabled.
default_retrieval_model = dict(
    search_method='semantic_search',
    reranking_enable=False,
    reranking_model=dict(
        reranking_provider_name='',
        reranking_model_name='',
    ),
    top_k=2,
    score_threshold_enable=False,
)
class RetrievalService:
    """Run a single retrieval strategy against a dataset and collect the hits.

    Both search methods take the Flask application explicitly and push their
    own application context, so they can be used as ``threading.Thread``
    targets. They return nothing; results are appended to the caller-supplied
    ``all_documents`` list.
    """

    @classmethod
    def embedding_search(cls, flask_app: Flask, dataset: Dataset, query: str,
                         top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
                         all_documents: list, search_method: str, embeddings: Embeddings):
        """Search the dataset's vector index and append the hits to ``all_documents``.

        Args:
            flask_app: Application whose context is pushed for the search.
            dataset: Dataset to search; its id is used as the ``group_id`` filter.
            query: Query text.
            top_k: Passed to the vector index as ``k``.
            score_threshold: Passed to the vector index and to the reranker;
                may be ``None``.
            reranking_model: Dict with ``reranking_provider_name`` and
                ``reranking_model_name``, or ``None`` to skip reranking.
            all_documents: Output list, extended in place with the results.
            search_method: Retrieval mode of the overall request. Reranking is
                applied here only when it equals ``'semantic_search'``.
            embeddings: Embeddings object handed to ``VectorIndex``.
        """
        with flask_app.app_context():
            vector_index = VectorIndex(
                dataset=dataset,
                config=current_app.config,
                embeddings=embeddings
            )

            documents = vector_index.search(
                query,
                search_type='similarity_score_threshold',
                search_kwargs={
                    'k': top_k,
                    'score_threshold': score_threshold,
                    'filter': {
                        'group_id': [dataset.id]
                    }
                }
            )

            cls._collect_documents(
                dataset=dataset,
                query=query,
                documents=documents,
                score_threshold=score_threshold,
                # Rerank here only for a pure semantic search.
                reranking_model=reranking_model if search_method == 'semantic_search' else None,
                all_documents=all_documents
            )

    @classmethod
    def full_text_index_search(cls, flask_app: Flask, dataset: Dataset, query: str,
                               top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
                               all_documents: list, search_method: str, embeddings: Embeddings):
        """Search the dataset's full-text index and append the hits to ``all_documents``.

        Args:
            flask_app: Application whose context is pushed for the search.
            dataset: Dataset to search.
            query: Query text.
            top_k: Maximum number of hits requested from the full-text index.
            score_threshold: Passed to the reranker only; may be ``None``.
            reranking_model: Dict with ``reranking_provider_name`` and
                ``reranking_model_name``, or ``None`` to skip reranking.
            all_documents: Output list, extended in place with the results.
            search_method: Retrieval mode of the overall request. Reranking is
                applied here only when it equals ``'full_text_search'``.
            embeddings: Embeddings object handed to ``VectorIndex``.
        """
        with flask_app.app_context():
            vector_index = VectorIndex(
                dataset=dataset,
                config=current_app.config,
                embeddings=embeddings
            )

            documents = vector_index.search_by_full_text_index(
                query,
                search_type='similarity_score_threshold',
                top_k=top_k
            )

            cls._collect_documents(
                dataset=dataset,
                query=query,
                documents=documents,
                score_threshold=score_threshold,
                # Rerank here only for a pure full-text search.
                reranking_model=reranking_model if search_method == 'full_text_search' else None,
                all_documents=all_documents
            )

    @classmethod
    def _collect_documents(cls, dataset: Dataset, query: str, documents: list,
                           score_threshold: Optional[float], reranking_model: Optional[dict],
                           all_documents: list):
        """Optionally rerank ``documents``, then append them to ``all_documents``.

        Does nothing when ``documents`` is empty. When ``reranking_model`` is
        falsy the documents are appended unchanged.
        """
        if not documents:
            return

        if reranking_model:
            rerank = ModelFactory.get_reranking_model(
                tenant_id=dataset.tenant_id,
                model_provider_name=reranking_model['reranking_provider_name'],
                model_name=reranking_model['reranking_model_name']
            )
            # len(documents) is passed as the result count so the reranker
            # is not asked to cut the list down.
            documents = rerank.rerank(query, documents, score_threshold, len(documents))

        all_documents.extend(documents)