import uuid

import click
import qdrant_client
from qdrant_client.http.models import TextIndexParams, TextIndexType, TokenizerType
from tqdm import tqdm
from flask import current_app, Flask
from langchain.embeddings import OpenAIEmbeddings

    click.echo(click.style('Congratulations! restore {} dataset indexes.'.format(len(normalization_count)), fg='green'))


@click.command('add-qdrant-full-text-index', help='add qdrant full text index')
def add_qdrant_full_text_index():
    click.echo(click.style('Start add full text index.', fg='green'))
    binds = db.session.query(DatasetCollectionBinding).all()
    if binds and current_app.config['VECTOR_STORE'] == 'qdrant':
        qdrant_url = current_app.config['QDRANT_URL']
        qdrant_api_key = current_app.config['QDRANT_API_KEY']
        client = qdrant_client.QdrantClient(
            qdrant_url,
            api_key=qdrant_api_key,  # For Qdrant Cloud, None for local instance
        )
        for bind in binds:
            try:
                text_index_params = TextIndexParams(
                    type=TextIndexType.TEXT,
                    tokenizer=TokenizerType.MULTILINGUAL,
                    min_token_len=2,
                    max_token_len=20,
                    lowercase=True
                )
                client.create_payload_index(bind.collection_name, 'page_content',
                                            field_schema=text_index_params)
            except Exception as e:
                click.echo(
                    click.style('Create full text index error: {} {}'.format(e.__class__.__name__, str(e)),
                                fg='red'))
            click.echo(
                click.style(
                    'Congratulations! add collection {} full text index successful.'.format(bind.collection_name),
                    fg='green'))
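# A minimal sketch of how the new payload index could be checked after the command
# runs, assuming the same QDRANT_URL / QDRANT_API_KEY settings; the collection name
# is illustrative. Qdrant reports payload indexes in the collection's payload_schema.
def check_full_text_index_sketch(qdrant_url: str, qdrant_api_key: str, collection_name: str) -> bool:
    client = qdrant_client.QdrantClient(qdrant_url, api_key=qdrant_api_key)
    collection_info = client.get_collection(collection_name)
    # True when a 'page_content' payload index (the full text index) is present
    return 'page_content' in (collection_info.payload_schema or {})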
def deal_dataset_vector(flask_app: Flask, dataset: Dataset, normalization_count: list):
    with flask_app.app_context():
        try:

            pbar.update(len(data_batch))


@click.command('migrate_default_input_to_dataset_query_variable')
@click.option("--batch-size", default=500, help="Number of records to migrate in each batch.")
def migrate_default_input_to_dataset_query_variable(batch_size):
    click.secho("Starting...", fg='green')

    total_records = db.session.query(AppModelConfig) \
        .filter(App.mode == 'completion') \
        .filter(AppModelConfig.dataset_query_variable == None) \
        .count()

    if total_records == 0:
        click.secho("No data to migrate.", fg='green')
        return

    num_batches = (total_records + batch_size - 1) // batch_size

    with tqdm(total=total_records, desc="Migrating Data") as pbar:
        for i in range(num_batches):
            offset = i * batch_size

                for form in user_input_form:
                    paragraph = form.get('paragraph')
                    if paragraph \
                            and paragraph.get('variable') == 'query':
                        data.dataset_query_variable = 'query'
                        break

                    if paragraph \
                            and paragraph.get('variable') == 'default_input':
                        data.dataset_query_variable = 'default_input'
                        break

                db.session.commit()

                click.secho(f"Error while migrating data: {e}, app_id: {data.app_id}, app_model_config_id: {data.id}",
                            fg='red')
                continue

            click.secho(f"Successfully migrated batch {i + 1}/{num_batches}.", fg='green')
            pbar.update(len(data_batch))


app.cli.add_command(update_app_model_configs)
app.cli.add_command(normalization_collections)
app.cli.add_command(migrate_default_input_to_dataset_query_variable)
app.cli.add_command(add_qdrant_full_text_index)
                            help='Invalid indexing technique.')
        parser.add_argument('permission', type=str, location='json', choices=(
            'only_me', 'all_team_members'), help='Invalid permission.')
        parser.add_argument('retrieval_model', type=dict, location='json', help='Invalid retrieval model.')
        args = parser.parse_args()

        # The role of the current user in the ta table must be admin or owner


class DatasetApiDeleteApi(Resource):
    resource_type = 'dataset'

    @setup_required
    @login_required
    @account_initialization_required

        }


class DatasetRetrievalSettingApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self):
        vector_type = current_app.config['VECTOR_STORE']
        if vector_type == 'milvus':
            return {
                'retrieval_method': [
                    'semantic_search'
                ]
            }
        elif vector_type == 'qdrant' or vector_type == 'weaviate':
            return {
                'retrieval_method': [
                    'semantic_search', 'full_text_search', 'hybrid_search'
                ]
            }
        else:
            raise ValueError("Unsupported vector db type.")


class DatasetRetrievalSettingMockApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, vector_type):
        if vector_type == 'milvus':
            return {
                'retrieval_method': [
                    'semantic_search'
                ]
            }
        elif vector_type == 'qdrant' or vector_type == 'weaviate':
            return {
                'retrieval_method': [
                    'semantic_search', 'full_text_search', 'hybrid_search'
                ]
            }
        else:
            raise ValueError("Unsupported vector db type.")


api.add_resource(DatasetListApi, '/datasets')
api.add_resource(DatasetApi, '/datasets/<uuid:dataset_id>')
api.add_resource(DatasetQueryApi, '/datasets/<uuid:dataset_id>/queries')
api.add_resource(DatasetApiKeyApi, '/datasets/api-keys')
api.add_resource(DatasetApiDeleteApi, '/datasets/api-keys/<uuid:api_key_id>')
api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info')
api.add_resource(DatasetRetrievalSettingApi, '/datasets/retrieval-setting')
api.add_resource(DatasetRetrievalSettingMockApi, '/datasets/retrieval-setting/<string:vector_type>')
        parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
        parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
                            location='json')
        parser.add_argument('retrieval_model', type=dict, required=False, nullable=False,
                            location='json')
        args = parser.parse_args()

        if not dataset.indexing_technique and not args['indexing_technique']:

        parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
        parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
                            location='json')
        parser.add_argument('retrieval_model', type=dict, required=False, nullable=False,
                            location='json')
        args = parser.parse_args()

        if args['indexing_technique'] == 'high_quality':
            try:
        parser = reqparse.RequestParser()
        parser.add_argument('query', type=str, location='json')
        parser.add_argument('retrieval_model', type=dict, required=False, location='json')
        args = parser.parse_args()

        HitTestingService.hit_testing_args_check(args)

        try:
            response = HitTestingService.retrieve(
                dataset=dataset,
                query=args['query'],
                account=current_user,
                retrieval_model=args['retrieval_model'],
                limit=10
            )

            return {"query": response['query'], 'records': marshal(response['records'], hit_testing_record_fields)}
    def get(self):
        parser = reqparse.RequestParser()
        parser.add_argument('model_type', type=str, required=True, nullable=False,
                            choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='args')
        args = parser.parse_args()

        tenant_id = current_user.current_tenant_id

    @account_initialization_required
    def post(self):
        parser = reqparse.RequestParser()
        parser.add_argument('model_settings', type=list, required=True, nullable=False, location='json')
        args = parser.parse_args()

        provider_service = ProviderService()
        model_settings = args['model_settings']
        for model_setting in model_settings:
            provider_service.update_default_model_of_model_type(
                tenant_id=current_user.current_tenant_id,
                model_type=model_setting['model_type'],
                provider_name=model_setting['provider_name'],
                model_name=model_setting['model_name']
            )

        return {'result': 'success'}
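# A minimal sketch of the request body this handler now expects: 'model_settings'
# is a list of per-type defaults, each carrying the keys the loop above reads
# (model_type, provider_name, model_name). The concrete values are illustrative.
example_model_settings_payload = {
    'model_settings': [
        {
            'model_type': 'text-generation',
            'provider_name': 'openai',
            'model_name': 'gpt-3.5-turbo'
        },
        {
            'model_type': 'reranking',
            'provider_name': 'cohere',
            'model_name': 'rerank-multilingual-v2.0'
        }
    ]
}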
                            location='json')
        parser.add_argument('indexing_technique', type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False,
                            location='json')
        parser.add_argument('retrieval_model', type=dict, required=False, nullable=False,
                            location='json')
        args = parser.parse_args()

        dataset_id = str(dataset_id)
        tenant_id = str(tenant_id)

        parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
        parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
                            location='json')
        parser.add_argument('retrieval_model', type=dict, required=False, nullable=False,
                            location='json')
        args = parser.parse_args()

        dataset_id = str(dataset_id)
        tenant_id = str(tenant_id)
from core.model_providers.models.entity.message import to_prompt_messages
from core.model_providers.models.llm.base import BaseLLM
from core.third_party.langchain.llms.fake import FakeLLM
from core.tool.dataset_retriever_tool import DatasetRetrieverTool


class MultiDatasetRouterAgent(OpenAIFunctionsAgent):

            return AgentFinish(return_values={"output": ''}, log='')
        elif len(self.tools) == 1:
            tool = next(iter(self.tools))
            tool = cast(DatasetRetrieverTool, tool)
            rst = tool.run(tool_input={'query': kwargs['input']})
            # output = ''
            # rst_json = json.loads(rst)
import json
from typing import Tuple, List, Any, Union, Sequence, Optional, cast

from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import AgentAction, AgentFinish, SystemMessage, Generation, LLMResult, AIMessage
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools import BaseTool
from pydantic import root_validator

from core.model_providers.models.entity.message import to_prompt_messages
from core.model_providers.models.llm.base import BaseLLM
from core.third_party.langchain.llms.fake import FakeLLM
from core.tool.dataset_retriever_tool import DatasetRetrieverTool


class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
    """
    A Multi Dataset Retrieve Agent driven by Router.
    """
    model_instance: BaseLLM

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    @root_validator
    def validate_llm(cls, values: dict) -> dict:
        return values

    def should_use_agent(self, query: str):
        """
        return should use agent

        :param query:
        :return:
        """
        return True

    def plan(
            self,
            intermediate_steps: List[Tuple[AgentAction, str]],
            callbacks: Callbacks = None,
            **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        """Given input, decided what to do.

        Args:
            intermediate_steps: Steps the LLM has taken to date, along with observations
            **kwargs: User inputs.

        Returns:
            Action specifying what tool to use.
        """
        if len(self.tools) == 0:
            return AgentFinish(return_values={"output": ''}, log='')
        elif len(self.tools) == 1:
            tool = next(iter(self.tools))
            tool = cast(DatasetRetrieverTool, tool)
            rst = tool.run(tool_input={'query': kwargs['input']})
            # output = ''
            # rst_json = json.loads(rst)
            # for item in rst_json:
            #     output += f'{item["content"]}\n'
            return AgentFinish(return_values={"output": rst}, log=rst)

        if intermediate_steps:
            _, observation = intermediate_steps[-1]
            return AgentFinish(return_values={"output": observation}, log=observation)

        try:
            agent_decision = self.real_plan(intermediate_steps, callbacks, **kwargs)
            if isinstance(agent_decision, AgentAction):
                tool_inputs = agent_decision.tool_input
                if isinstance(tool_inputs, dict) and 'query' in tool_inputs and 'chat_history' not in kwargs:
                    tool_inputs['query'] = kwargs['input']
                    agent_decision.tool_input = tool_inputs
            else:
                agent_decision.return_values['output'] = ''
            return agent_decision
        except Exception as e:
            new_exception = self.model_instance.handle_exceptions(e)
            raise new_exception

    def real_plan(
            self,
            intermediate_steps: List[Tuple[AgentAction, str]],
            callbacks: Callbacks = None,
            **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        """Given input, decided what to do.

        Args:
            intermediate_steps: Steps the LLM has taken to date, along with observations
            **kwargs: User inputs.

        Returns:
            Action specifying what tool to use.
        """
        agent_scratchpad = _format_intermediate_steps(intermediate_steps)
        selected_inputs = {
            k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
        }
        full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
        prompt = self.prompt.format_prompt(**full_inputs)
        messages = prompt.to_messages()
        prompt_messages = to_prompt_messages(messages)
        result = self.model_instance.run(
            messages=prompt_messages,
            functions=self.functions,
        )

        ai_message = AIMessage(
            content=result.content,
            additional_kwargs={
                'function_call': result.function_call
            }
        )
        agent_decision = _parse_ai_message(ai_message)
        return agent_decision

    async def aplan(
            self,
            intermediate_steps: List[Tuple[AgentAction, str]],
            callbacks: Callbacks = None,
            **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        raise NotImplementedError()

    @classmethod
    def from_llm_and_tools(
            cls,
            model_instance: BaseLLM,
            tools: Sequence[BaseTool],
            callback_manager: Optional[BaseCallbackManager] = None,
            extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
            system_message: Optional[SystemMessage] = SystemMessage(
                content="You are a helpful AI assistant."
            ),
            **kwargs: Any,
    ) -> BaseSingleActionAgent:
        prompt = cls.create_prompt(
            extra_prompt_messages=extra_prompt_messages,
            system_message=system_message,
        )
        return cls(
            model_instance=model_instance,
            llm=FakeLLM(response=''),
            prompt=prompt,
            tools=tools,
            callback_manager=callback_manager,
            **kwargs,
        )
            return AgentFinish(return_values={"output": ''}, log='')
        elif len(self.dataset_tools) == 1:
            tool = next(iter(self.dataset_tools))
            tool = cast(DatasetRetrieverTool, tool)
            rst = tool.run(tool_input={'query': kwargs['input']})
            return AgentFinish(return_values={"output": rst}, log=rst)
from core.helper import moderation
from core.model_providers.error import LLMError
from core.model_providers.models.llm.base import BaseLLM
from core.tool.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
from core.tool.dataset_retriever_tool import DatasetRetrieverTool

                verbose=True
            )
        elif self.configuration.strategy == PlanningStrategy.ROUTER:
            self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool) or isinstance(t, DatasetMultiRetrieverTool)]
            agent = MultiDatasetRouterAgent.from_llm_and_tools(
                model_instance=self.configuration.model_instance,
                tools=self.configuration.tools,
                verbose=True
            )
        elif self.configuration.strategy == PlanningStrategy.REACT_ROUTER:
            self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool) or isinstance(t, DatasetMultiRetrieverTool)]
            agent = StructuredMultiDatasetRouterAgent.from_llm_and_tools(
                model_instance=self.configuration.model_instance,
                tools=self.configuration.tools,
class DatasetIndexToolCallbackHandler:
    """Callback handler for dataset tool."""

    def __init__(self, conversation_message_task: ConversationMessageTask) -> None:
        self.conversation_message_task = conversation_message_task

    def on_tool_end(self, documents: List[Document]) -> None:

            # add hit count to document segment
            db.session.query(DocumentSegment).filter(
                DocumentSegment.index_node_id == doc_id
            ).update(
                {DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
                memory=memory,
                rest_tokens=rest_tokens_for_context_and_memory,
                chain_callback=chain_callback,
                tenant_id=app.tenant_id,
                retriever_from=retriever_from
            )
from typing import List, Union, Optional

import requests
from langchain.document_loaders import TextLoader, Docx2txtLoader, UnstructuredFileLoader, UnstructuredAPIFileLoader
from langchain.schema import Document

from core.data_loader.loader.csv_loader import CSVLoader


class FileExtractor:
    @classmethod
    def load(cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False) -> Union[List[Document] | str]:
        with tempfile.TemporaryDirectory() as temp_dir:
            suffix = Path(upload_file.key).suffix
            file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
            storage.download(upload_file.key, file_path)

            return cls.load_from_file(file_path, return_text, upload_file, is_automatic)

    @classmethod
    def load_from_url(cls, url: str, return_text: bool = False) -> Union[List[Document] | str]:

    @classmethod
    def load_from_file(cls, file_path: str, return_text: bool = False,
                       upload_file: Optional[UploadFile] = None,
                       is_automatic: bool = False) -> Union[List[Document] | str]:
        input_file = Path(file_path)
        delimiter = '\n'
        file_extension = input_file.suffix.lower()
        if is_automatic:
            loader = UnstructuredFileLoader(
                file_path, strategy="hi_res", mode="elements"
            )
            # loader = UnstructuredAPIFileLoader(
            #     file_path=filenames[0],
            #     api_key="FAKE_API_KEY",
            # )
        else:
            if file_extension == '.xlsx':
                loader = ExcelLoader(file_path)
            elif file_extension == '.pdf':
                loader = PdfLoader(file_path, upload_file=upload_file)
            elif file_extension in ['.md', '.markdown']:
                loader = MarkdownLoader(file_path, autodetect_encoding=True)
            elif file_extension in ['.htm', '.html']:
                loader = HTMLLoader(file_path)
            elif file_extension == '.docx':
                loader = Docx2txtLoader(file_path)
            elif file_extension == '.csv':
                loader = CSVLoader(file_path, autodetect_encoding=True)
            else:
                # txt
                loader = TextLoader(file_path, autodetect_encoding=True)

        return delimiter.join([document.page_content for document in loader.load()]) if return_text else loader.load()
    def _get_vector_store_class(self) -> type:
        raise NotImplementedError

    @abstractmethod
    def search_by_full_text_index(
            self, query: str,
            **kwargs: Any
    ) -> List[Document]:
        raise NotImplementedError
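    # A minimal illustrative sketch (not part of the interface itself) of how the two
    # retrieval paths can be combined for the 'hybrid_search' method exposed by the
    # retrieval-setting API: run semantic and full-text retrieval, merge candidates by
    # doc_id, then let a reranking model order them. 'vector_index' and 'reranker' are
    # assumed to be an initialized BaseVectorIndex subclass and a BaseReranking instance;
    # the search() keyword names follow the existing contract and are otherwise assumptions.
    #
    # def hybrid_search_sketch(vector_index, reranker, query: str, top_k: int = 2):
    #     semantic_docs = vector_index.search(query, search_type='similarity', search_kwargs={'k': top_k})
    #     full_text_docs = vector_index.search_by_full_text_index(query, top_k=top_k)
    #
    #     seen, candidates = set(), []
    #     for doc in semantic_docs + full_text_docs:
    #         doc_id = doc.metadata.get('doc_id')
    #         if doc_id not in seen:
    #             seen.add(doc_id)
    #             candidates.append(doc)
    #
    #     return reranker.rerank(query=query, documents=candidates, score_threshold=None, top_k=top_k)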
    def search(
            self, query: str,
            **kwargs: Any
from typing import cast, Any, List

from langchain.embeddings.base import Embeddings
from langchain.schema import Document
from langchain.vectorstores import VectorStore
from pydantic import BaseModel, root_validator

from core.index.base import BaseIndex
from core.index.vector_index.base import BaseVectorIndex
from core.vector_store.milvus_vector_store import MilvusVectorStore
from extensions.ext_database import db
from models.dataset import Dataset


class MilvusConfig(BaseModel):

        index_params = {
            'metric_type': 'IP',
            'index_type': "HNSW",
            'params': {"M": 8, "efConstruction": 64}
        }
        self._vector_store = MilvusVectorStore.from_documents(
            texts,

                ),
            ],
        ))

    def search_by_full_text_index(self, query: str, **kwargs: Any) -> List[Document]:
        # milvus/zilliz doesn't support bm25 search
        return []
            return True

        return False

    def search_by_full_text_index(self, query: str, **kwargs: Any) -> List[Document]:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        from qdrant_client.http import models
        return vector_store.similarity_search_by_bm25(models.Filter(
            must=[
                models.FieldCondition(
                    key="group_id",
                    match=models.MatchValue(value=self.dataset.id),
                ),
                models.FieldCondition(
                    key="page_content",
                    match=models.MatchText(text=query),
                )
            ],
        ), kwargs.get('top_k', 2))
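# For reference, a minimal sketch of what the underlying Qdrant request amounts to:
# scroll the collection with a MatchText condition on the 'page_content' payload
# field, relying on the full text payload index created by the
# add-qdrant-full-text-index command above. Collection and field values are illustrative.
import qdrant_client
from qdrant_client.http import models

def qdrant_full_text_scroll_sketch(client: qdrant_client.QdrantClient,
                                   collection_name: str, dataset_id: str,
                                   query: str, top_k: int = 2):
    points, _next_page = client.scroll(
        collection_name=collection_name,
        scroll_filter=models.Filter(
            must=[
                models.FieldCondition(key="group_id", match=models.MatchValue(value=dataset_id)),
                models.FieldCondition(key="page_content", match=models.MatchText(text=query)),
            ],
        ),
        limit=top_k,
        with_payload=True,
    )
    return points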
from typing import Optional, cast, Any, List

import requests
import weaviate


class WeaviateVectorIndex(BaseVectorIndex):
    def __init__(self, dataset: Dataset, config: WeaviateConfig, embeddings: Embeddings):
        super().__init__(dataset, embeddings)
        self._client = self._init_client(config)

            return True

        return False

    def search_by_full_text_index(self, query: str, **kwargs: Any) -> List[Document]:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)
        return vector_store.similarity_search_by_bm25(query, kwargs.get('top_k', 2), **kwargs)
        if not dataset:
            raise ValueError("no dataset found")

        # get the process rule
        processing_rule = db.session.query(DatasetProcessRule). \
            filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
            first()

        # load file
        text_docs = self._load_data(dataset_document)

        # get splitter
        splitter = self._get_splitter(processing_rule)

            "preview": preview_texts
        }

    def _load_data(self, dataset_document: DatasetDocument, automatic: bool = False) -> List[Document]:
        # load file
        if dataset_document.data_source_type not in ["upload_file", "notion_import"]:
            return []

                one_or_none()

            if file_detail:
                text_docs = FileExtractor.load(file_detail, is_automatic=False)
        elif dataset_document.data_source_type == 'notion_import':
            loader = NotionLoader.from_document(dataset_document)
            text_docs = loader.load()
from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.moderation.base import BaseModeration
from core.model_providers.models.reranking.base import BaseReranking
from core.model_providers.models.speech2text.base import BaseSpeech2Text
from extensions.ext_database import db
from models.provider import TenantDefaultModel

            name=model_name
        )

    @classmethod
    def get_reranking_model(cls,
                            tenant_id: str,
                            model_provider_name: Optional[str] = None,
                            model_name: Optional[str] = None) -> Optional[BaseReranking]:
        """
        get reranking model.

        :param tenant_id: a string representing the ID of the tenant.
        :param model_provider_name:
        :param model_name:
        :return:
        """
        if model_provider_name is None and model_name is None:
            default_model = cls.get_default_model(tenant_id, ModelType.RERANKING)
            if not default_model:
                raise LLMBadRequestError(f"Default model is not available. "
                                         f"Please configure a Default Reranking Model "
                                         f"in the Settings -> Model Provider.")

            model_provider_name = default_model.provider_name
            model_name = default_model.model_name

        # get model provider
        model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)

        if not model_provider:
            raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")

        # init reranking model
        model_class = model_provider.get_model_class(model_type=ModelType.RERANKING)

        return model_class(
            model_provider=model_provider,
            name=model_name
        )

    @classmethod
    def get_speech2text_model(cls,
                              tenant_id: str,
        elif provider_name == 'localai':
            from core.model_providers.providers.localai_provider import LocalAIProvider
            return LocalAIProvider
        elif provider_name == 'cohere':
            from core.model_providers.providers.cohere_provider import CohereProvider
            return CohereProvider
        else:
            raise NotImplementedError


    IMAGE = 'image'
    VIDEO = 'video'
    MODERATION = 'moderation'
    RERANKING = 'reranking'

    @staticmethod
    def value_of(value):
        for member in ModelType:
from abc import abstractmethod
from typing import Any, Optional, List

from langchain.schema import Document

from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.base import BaseModelProvider
import logging

logger = logging.getLogger(__name__)


class BaseReranking(BaseProviderModel):
    name: str
    type: ModelType = ModelType.RERANKING

    def __init__(self, model_provider: BaseModelProvider, client: Any, name: str):
        super().__init__(model_provider, client)
        self.name = name

    @property
    def base_model_name(self) -> str:
        """
        get base model name

        :return: str
        """
        return self.name

    @abstractmethod
    def rerank(self, query: str, documents: List[Document], score_threshold: Optional[float], top_k: Optional[int]) -> Optional[List[Document]]:
        raise NotImplementedError

    @abstractmethod
    def handle_exceptions(self, ex: Exception) -> Exception:
        raise NotImplementedError
import logging
from typing import Optional, List

import cohere
import openai
from langchain.schema import Document

from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
    LLMRateLimitError, LLMAuthorizationError
from core.model_providers.models.reranking.base import BaseReranking
from core.model_providers.providers.base import BaseModelProvider


class CohereReranking(BaseReranking):

    def __init__(self, model_provider: BaseModelProvider, name: str):
        self.credentials = model_provider.get_model_credentials(
            model_name=name,
            model_type=self.type
        )

        client = cohere.Client(self.credentials.get('api_key'))

        super().__init__(model_provider, client, name)

    def rerank(self, query: str, documents: List[Document], score_threshold: Optional[float], top_k: Optional[int]) -> Optional[List[Document]]:
        docs = []
        doc_id = []
        for document in documents:
            if document.metadata['doc_id'] not in doc_id:
                doc_id.append(document.metadata['doc_id'])
                docs.append(document.page_content)

        results = self.client.rerank(query=query, documents=docs, model=self.name, top_n=top_k)

        rerank_documents = []
        for idx, result in enumerate(results):
            # format document
            rerank_document = Document(
                page_content=result.document['text'],
                metadata={
                    "doc_id": documents[result.index].metadata['doc_id'],
                    "doc_hash": documents[result.index].metadata['doc_hash'],
                    "document_id": documents[result.index].metadata['document_id'],
                    "dataset_id": documents[result.index].metadata['dataset_id'],
                    'score': result.relevance_score
                }
            )
            # score threshold check
            if score_threshold is not None:
                if result.relevance_score >= score_threshold:
                    rerank_documents.append(rerank_document)
            else:
                rerank_documents.append(rerank_document)

        return rerank_documents

    def handle_exceptions(self, ex: Exception) -> Exception:
        if isinstance(ex, openai.error.InvalidRequestError):
            logging.warning("Invalid request to OpenAI API.")
            return LLMBadRequestError(str(ex))
        elif isinstance(ex, openai.error.APIConnectionError):
            logging.warning("Failed to connect to OpenAI API.")
            return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
        elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
            logging.warning("OpenAI service unavailable.")
            return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
        elif isinstance(ex, openai.error.RateLimitError):
            return LLMRateLimitError(str(ex))
        elif isinstance(ex, openai.error.AuthenticationError):
            return LLMAuthorizationError(str(ex))
        elif isinstance(ex, openai.error.OpenAIError):
            return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
        else:
            return ex
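# A small illustrative sketch of the raw Cohere call that CohereReranking.rerank()
# wraps, assuming a valid API key; the model name and texts are placeholders. The
# loop mirrors how the class above reads result.index and result.relevance_score.
def cohere_rerank_sketch(api_key: str, query: str, texts: List[str], top_n: int = 3):
    client = cohere.Client(api_key)
    results = client.rerank(query=query, documents=texts, model='rerank-multilingual-v2.0', top_n=top_n)
    # each result points back into `texts` via result.index and carries a relevance score
    return [(texts[result.index], result.relevance_score) for result in results]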
import json
from json import JSONDecodeError
from typing import Type

from langchain.schema import HumanMessage

from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.reranking.cohere_reranking import CohereReranking
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from models.provider import ProviderType


class CohereProvider(BaseModelProvider):
    @property
    def provider_name(self):
        """
        Returns the name of a provider.
        """
        return 'cohere'

    def _get_text_generation_model_mode(self, model_name) -> str:
        return ModelMode.CHAT.value

    def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
        if model_type == ModelType.RERANKING:
            return [
                {
                    'id': 'rerank-english-v2.0',
                    'name': 'rerank-english-v2.0'
                },
                {
                    'id': 'rerank-multilingual-v2.0',
                    'name': 'rerank-multilingual-v2.0'
                }
            ]
        else:
            return []

    def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
        """
        Returns the model class.

        :param model_type:
        :return:
        """
        if model_type == ModelType.RERANKING:
            model_class = CohereReranking
        else:
            raise NotImplementedError

        return model_class

    def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
        """
        get model parameter rules.

        :param model_name:
        :param model_type:
        :return:
        """
        return ModelKwargsRules(
            temperature=KwargRule[float](min=0, max=1, default=0.3, precision=2),
            top_p=KwargRule[float](min=0, max=0.99, default=0.85, precision=2),
            presence_penalty=KwargRule[float](enabled=False),
            frequency_penalty=KwargRule[float](enabled=False),
            max_tokens=KwargRule[int](enabled=False),
        )

    @classmethod
    def is_provider_credentials_valid_or_raise(cls, credentials: dict):
        """
        Validates the given credentials.
        """
        if 'api_key' not in credentials:
            raise CredentialsValidateFailedError('Cohere api_key must be provided.')

        try:
            credential_kwargs = {
                'api_key': credentials['api_key'],
            }
            # todo validate
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    @classmethod
    def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
        credentials['api_key'] = encrypter.encrypt_token(tenant_id, credentials['api_key'])
        return credentials

    def get_provider_credentials(self, obfuscated: bool = False) -> dict:
        if self.provider.provider_type == ProviderType.CUSTOM.value:
            try:
                credentials = json.loads(self.provider.encrypted_config)
            except JSONDecodeError:
                credentials = {
                    'api_key': None,
                }

            if credentials['api_key']:
                credentials['api_key'] = encrypter.decrypt_token(
                    self.provider.tenant_id,
                    credentials['api_key']
                )

                if obfuscated:
                    credentials['api_key'] = encrypter.obfuscated_token(credentials['api_key'])

            return credentials
        else:
            return {}

    def should_deduct_quota(self):
        return True

    @classmethod
    def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
        """
        check model credentials valid.

        :param model_name:
        :param model_type:
        :param credentials:
        """
        return

    @classmethod
    def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
                                  credentials: dict) -> dict:
        """
        encrypt model credentials for save.

        :param tenant_id:
        :param model_name:
        :param model_type:
        :param credentials:
        :return:
        """
        return {}

    def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
        """
        get credentials for llm use.

        :param model_name:
        :param model_type:
        :param obfuscated:
        :return:
        """
        return self.get_provider_credentials(obfuscated)
    "huggingface_hub",
    "xinference",
    "openllm",
    "localai",
    "cohere"
]
{
    "support_provider_types": [
        "custom"
    ],
    "system_config": null,
    "model_flexibility": "fixed"
}
import json
import threading
from typing import Optional, List

from flask import Flask
from langchain import WikipediaAPIWrapper
from langchain.callbacks.manager import Callbacks
from langchain.memory.chat_memory import BaseChatMemory
from langchain.tools import BaseTool, Tool, WikipediaQueryRun
from pydantic import BaseModel, Field

from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
from core.agent.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
from core.agent.agent_executor import AgentExecutor, PlanningStrategy, AgentConfiguration
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
from core.model_providers.models.entity.model_params import ModelKwargs, ModelMode
from core.model_providers.models.llm.base import BaseLLM
from core.tool.current_datetime_tool import DatetimeTool
from core.tool.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
from core.tool.provider.serpapi_provider import SerpAPIToolProvider
from core.tool.serpapi_wrapper import OptimizedSerpAPIWrapper, OptimizedSerpAPIInput
from models.dataset import Dataset, DatasetProcessRule
from models.model import AppModelConfig

default_retrieval_model = {
    'search_method': 'semantic_search',
    'reranking_enable': False,
    'reranking_model': {
        'reranking_provider_name': '',
        'reranking_model_name': ''
    },
    'top_k': 2,
    'score_threshold_enable': False
}
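# An illustrative, fully populated retrieval_model configuration, assuming the same
# keys as default_retrieval_model above plus the 'score_threshold' value read in
# to_dataset_retriever_tool below; provider and model names are placeholders.
example_retrieval_model = {
    'search_method': 'hybrid_search',  # or 'semantic_search' / 'full_text_search'
    'reranking_enable': True,
    'reranking_model': {
        'reranking_provider_name': 'cohere',
        'reranking_model_name': 'rerank-multilingual-v2.0'
    },
    'top_k': 4,
    'score_threshold_enable': True,
    'score_threshold': 0.5
}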
class OrchestratorRuleParser:
    """Parse the orchestrator rule to entities."""

        self.app_model_config = app_model_config

    def to_agent_executor(self, conversation_message_task: ConversationMessageTask, memory: Optional[BaseChatMemory],
                          rest_tokens: int, chain_callback: MainChainGatherCallbackHandler, tenant_id: str,
                          retriever_from: str = 'dev') -> Optional[AgentExecutor]:
        if not self.app_model_config.agent_mode_dict:
            return None

                rest_tokens=rest_tokens,
                return_resource=return_resource,
                retriever_from=retriever_from,
                dataset_configs=dataset_configs,
                tenant_id=tenant_id
            )

            if len(tools) == 0:
                return chain
    def to_tools(self, tool_configs: list, callbacks: Callbacks = None, **kwargs) -> list[BaseTool]:
        """
        Convert app agent tool configs to tools

        :return:
        """
        tools = []
        dataset_tools = []
        for tool_config in tool_configs:
            tool_type = list(tool_config.keys())[0]
            tool_val = list(tool_config.values())[0]

            tool = None
            if tool_type == "dataset":
                dataset_tools.append(tool_config)
            elif tool_type == "web_reader":
                tool = self.to_web_reader_tool(tool_config=tool_val, **kwargs)
            elif tool_type == "google_search":

            else:
                tool.callbacks = callbacks

            tools.append(tool)

        # format dataset tool
        if len(dataset_tools) > 0:
            dataset_retriever_tools = self.to_dataset_retriever_tool(tool_configs=dataset_tools, **kwargs)
            if dataset_retriever_tools:
                tools.extend(dataset_retriever_tools)

        return tools
| def to_dataset_retriever_tool(self, tool_config: dict, conversation_message_task: ConversationMessageTask, | |||||
| dataset_configs: dict, rest_tokens: int, | |||||
| def to_dataset_retriever_tool(self, tool_configs: List, conversation_message_task: ConversationMessageTask, | |||||
| return_resource: bool = False, retriever_from: str = 'dev', | return_resource: bool = False, retriever_from: str = 'dev', | ||||
| **kwargs) \ | **kwargs) \ | ||||
| -> Optional[BaseTool]: | |||||
| -> Optional[List[BaseTool]]: | |||||
| """ | """ | ||||
| A dataset tool is a tool that can be used to retrieve information from a dataset | A dataset tool is a tool that can be used to retrieve information from a dataset | ||||
| :param rest_tokens: | |||||
| :param tool_config: | |||||
| :param dataset_configs: | |||||
| :param tool_configs: | |||||
| :param conversation_message_task: | :param conversation_message_task: | ||||
| :param return_resource: | :param return_resource: | ||||
| :param retriever_from: | :param retriever_from: | ||||
| :return: | :return: | ||||
| """ | """ | ||||
| # get dataset from dataset id | |||||
| dataset = db.session.query(Dataset).filter( | |||||
| Dataset.tenant_id == self.tenant_id, | |||||
| Dataset.id == tool_config.get("id") | |||||
| ).first() | |||||
| if not dataset: | |||||
| return None | |||||
| if dataset and dataset.available_document_count == 0 and dataset.available_document_count == 0: | |||||
| return None | |||||
| top_k = dataset_configs.get("top_k", 2) | |||||
| # dynamically adjust top_k when the remaining token number is not enough to support top_k | |||||
| top_k = self._dynamic_calc_retrieve_k(dataset=dataset, top_k=top_k, rest_tokens=rest_tokens) | |||||
| dataset_configs = kwargs['dataset_configs'] | |||||
| retrieval_model = dataset_configs.get('retrieval_model', 'single') | |||||
| tools = [] | |||||
| dataset_ids = [] | |||||
| tenant_id = None | |||||
| for tool_config in tool_configs: | |||||
| # get dataset from dataset id | |||||
| dataset = db.session.query(Dataset).filter( | |||||
| Dataset.tenant_id == self.tenant_id, | |||||
| Dataset.id == tool_config.get('dataset').get("id") | |||||
| ).first() | |||||
| score_threshold = None | |||||
| score_threshold_config = dataset_configs.get("score_threshold") | |||||
| if score_threshold_config and score_threshold_config.get("enable"): | |||||
| score_threshold = score_threshold_config.get("value") | |||||
| if not dataset: | |||||
| return None | |||||
| tool = DatasetRetrieverTool.from_dataset( | |||||
| dataset=dataset, | |||||
| top_k=top_k, | |||||
| score_threshold=score_threshold, | |||||
| callbacks=[DatasetToolCallbackHandler(conversation_message_task)], | |||||
| conversation_message_task=conversation_message_task, | |||||
| return_resource=return_resource, | |||||
| retriever_from=retriever_from | |||||
| ) | |||||
| if dataset and dataset.available_document_count == 0: | |||||
| return None | |||||
| dataset_ids.append(dataset.id) | |||||
| if retrieval_model == 'single': | |||||
| retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model | |||||
| top_k = retrieval_model['top_k'] | |||||
| # dynamically adjust top_k when the remaining token number is not enough to support top_k | |||||
| # top_k = self._dynamic_calc_retrieve_k(dataset=dataset, top_k=top_k, rest_tokens=rest_tokens) | |||||
| score_threshold = None | |||||
| score_threshold_enable = retrieval_model.get("score_threshold_enable") | |||||
| if score_threshold_enable: | |||||
| score_threshold = retrieval_model.get("score_threshold") | |||||
| tool = DatasetRetrieverTool.from_dataset( | |||||
| dataset=dataset, | |||||
| top_k=top_k, | |||||
| score_threshold=score_threshold, | |||||
| callbacks=[DatasetToolCallbackHandler(conversation_message_task)], | |||||
| conversation_message_task=conversation_message_task, | |||||
| return_resource=return_resource, | |||||
| retriever_from=retriever_from | |||||
| ) | |||||
| tools.append(tool) | |||||
| if retrieval_model == 'multiple': | |||||
| tool = DatasetMultiRetrieverTool.from_dataset( | |||||
| dataset_ids=dataset_ids, | |||||
| tenant_id=kwargs['tenant_id'], | |||||
| top_k=dataset_configs.get('top_k', 2), | |||||
| score_threshold=dataset_configs.get('score_threshold', 0.5) if dataset_configs.get('score_threshold_enable', False) else None, | |||||
| callbacks=[DatasetToolCallbackHandler(conversation_message_task)], | |||||
| conversation_message_task=conversation_message_task, | |||||
| return_resource=return_resource, | |||||
| retriever_from=retriever_from, | |||||
| reranking_provider_name=dataset_configs.get('reranking_model').get('reranking_provider_name'), | |||||
| reranking_model_name=dataset_configs.get('reranking_model').get('reranking_model_name') | |||||
| ) | |||||
| tools.append(tool) | |||||
| return tool | |||||
| return tools | |||||
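A minimal sketch of the input shapes this method now expects, inferred from the diff above; the dataset ids and reranking model names are placeholders, not values taken from this change:

    # Hypothetical inputs for to_dataset_retriever_tool (shapes inferred from the code above)
    tool_configs = [
        {'dataset': {'id': 'dataset-id-1'}},   # one entry per dataset bound to the agent
        {'dataset': {'id': 'dataset-id-2'}},
    ]
    dataset_configs = {
        'retrieval_model': 'multiple',          # 'single' falls back to each dataset's own retrieval_model
        'top_k': 2,
        'score_threshold_enable': True,
        'score_threshold': 0.5,
        'reranking_model': {
            'reranking_provider_name': 'cohere',            # illustrative
            'reranking_model_name': 'rerank-english-v2.0'   # illustrative
        }
    }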
| def to_web_reader_tool(self, tool_config: dict, agent_model_instance: BaseLLM, **kwargs) -> Optional[BaseTool]: | def to_web_reader_tool(self, tool_config: dict, agent_model_instance: BaseLLM, **kwargs) -> Optional[BaseTool]: | ||||
| """ | """ |
| import json | |||||
| import threading | |||||
| from typing import Type, Optional, List | |||||
| from flask import current_app, Flask | |||||
| from langchain.tools import BaseTool | |||||
| from pydantic import Field, BaseModel | |||||
| from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler | |||||
| from core.conversation_message_task import ConversationMessageTask | |||||
| from core.embedding.cached_embedding import CacheEmbedding | |||||
| from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig | |||||
| from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError | |||||
| from core.model_providers.model_factory import ModelFactory | |||||
| from extensions.ext_database import db | |||||
| from models.dataset import Dataset, DocumentSegment, Document | |||||
| from services.retrieval_service import RetrievalService | |||||
| default_retrieval_model = { | |||||
| 'search_method': 'semantic_search', | |||||
| 'reranking_enable': False, | |||||
| 'reranking_model': { | |||||
| 'reranking_provider_name': '', | |||||
| 'reranking_model_name': '' | |||||
| }, | |||||
| 'top_k': 2, | |||||
| 'score_threshold_enable': False | |||||
| } | |||||
| class DatasetMultiRetrieverToolInput(BaseModel): | |||||
| query: str = Field(..., description="Query for the multi-dataset retriever") | |||||
| class DatasetMultiRetrieverTool(BaseTool): | |||||
| """Tool for querying multi dataset.""" | |||||
| name: str = "dataset-" | |||||
| args_schema: Type[BaseModel] = DatasetMultiRetrieverToolInput | |||||
| description: str = "dataset multi retriever and rerank. " | |||||
| tenant_id: str | |||||
| dataset_ids: List[str] | |||||
| top_k: int = 2 | |||||
| score_threshold: Optional[float] = None | |||||
| reranking_provider_name: str | |||||
| reranking_model_name: str | |||||
| conversation_message_task: ConversationMessageTask | |||||
| return_resource: bool | |||||
| retriever_from: str | |||||
| @classmethod | |||||
| def from_dataset(cls, dataset_ids: List[str], tenant_id: str, **kwargs): | |||||
| return cls( | |||||
| name=f'dataset-{tenant_id}', | |||||
| tenant_id=tenant_id, | |||||
| dataset_ids=dataset_ids, | |||||
| **kwargs | |||||
| ) | |||||
| def _run(self, query: str) -> str: | |||||
| threads = [] | |||||
| all_documents = [] | |||||
| for dataset_id in self.dataset_ids: | |||||
| retrieval_thread = threading.Thread(target=self._retriever, kwargs={ | |||||
| 'flask_app': current_app._get_current_object(), | |||||
| 'dataset_id': dataset_id, | |||||
| 'query': query, | |||||
| 'all_documents': all_documents | |||||
| }) | |||||
| threads.append(retrieval_thread) | |||||
| retrieval_thread.start() | |||||
| for thread in threads: | |||||
| thread.join() | |||||
| # rerank the documents gathered from all datasets | |||||
| rerank = ModelFactory.get_reranking_model( | |||||
| tenant_id=self.tenant_id, | |||||
| model_provider_name=self.reranking_provider_name, | |||||
| model_name=self.reranking_model_name | |||||
| ) | |||||
| all_documents = rerank.rerank(query, all_documents, self.score_threshold, self.top_k) | |||||
| hit_callback = DatasetIndexToolCallbackHandler(self.conversation_message_task) | |||||
| hit_callback.on_tool_end(all_documents) | |||||
| document_context_list = [] | |||||
| index_node_ids = [document.metadata['doc_id'] for document in all_documents] | |||||
| segments = DocumentSegment.query.filter( | |||||
| DocumentSegment.completed_at.isnot(None), | |||||
| DocumentSegment.status == 'completed', | |||||
| DocumentSegment.enabled == True, | |||||
| DocumentSegment.index_node_id.in_(index_node_ids) | |||||
| ).all() | |||||
| if segments: | |||||
| index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} | |||||
| sorted_segments = sorted(segments, | |||||
| key=lambda segment: index_node_id_to_position.get(segment.index_node_id, | |||||
| float('inf'))) | |||||
| for segment in sorted_segments: | |||||
| if segment.answer: | |||||
| document_context_list.append(f'question:{segment.content} answer:{segment.answer}') | |||||
| else: | |||||
| document_context_list.append(segment.content) | |||||
| if self.return_resource: | |||||
| context_list = [] | |||||
| resource_number = 1 | |||||
| for segment in sorted_segments: | |||||
| dataset = Dataset.query.filter_by( | |||||
| id=segment.dataset_id | |||||
| ).first() | |||||
| document = Document.query.filter(Document.id == segment.document_id, | |||||
| Document.enabled == True, | |||||
| Document.archived == False, | |||||
| ).first() | |||||
| if dataset and document: | |||||
| source = { | |||||
| 'position': resource_number, | |||||
| 'dataset_id': dataset.id, | |||||
| 'dataset_name': dataset.name, | |||||
| 'document_id': document.id, | |||||
| 'document_name': document.name, | |||||
| 'data_source_type': document.data_source_type, | |||||
| 'segment_id': segment.id, | |||||
| 'retriever_from': self.retriever_from | |||||
| } | |||||
| if self.retriever_from == 'dev': | |||||
| source['hit_count'] = segment.hit_count | |||||
| source['word_count'] = segment.word_count | |||||
| source['segment_position'] = segment.position | |||||
| source['index_node_hash'] = segment.index_node_hash | |||||
| if segment.answer: | |||||
| source['content'] = f'question:{segment.content} \nanswer:{segment.answer}' | |||||
| else: | |||||
| source['content'] = segment.content | |||||
| context_list.append(source) | |||||
| resource_number += 1 | |||||
| hit_callback.return_retriever_resource_info(context_list) | |||||
| return str("\n".join(document_context_list)) | |||||
| async def _arun(self, tool_input: str) -> str: | |||||
| raise NotImplementedError() | |||||
| def _retriever(self, flask_app: Flask, dataset_id: str, query: str, all_documents: List): | |||||
| with flask_app.app_context(): | |||||
| dataset = db.session.query(Dataset).filter( | |||||
| Dataset.tenant_id == self.tenant_id, | |||||
| Dataset.id == dataset_id | |||||
| ).first() | |||||
| if not dataset: | |||||
| return [] | |||||
| # get the dataset's retrieval model; fall back to the default if it is not set | |||||
| retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model | |||||
| if dataset.indexing_technique == "economy": | |||||
| # use keyword table query | |||||
| kw_table_index = KeywordTableIndex( | |||||
| dataset=dataset, | |||||
| config=KeywordTableConfig( | |||||
| max_keywords_per_chunk=5 | |||||
| ) | |||||
| ) | |||||
| documents = kw_table_index.search(query, search_kwargs={'k': self.top_k}) | |||||
| if documents: | |||||
| all_documents.extend(documents) | |||||
| else: | |||||
| try: | |||||
| embedding_model = ModelFactory.get_embedding_model( | |||||
| tenant_id=dataset.tenant_id, | |||||
| model_provider_name=dataset.embedding_model_provider, | |||||
| model_name=dataset.embedding_model | |||||
| ) | |||||
| except LLMBadRequestError: | |||||
| return [] | |||||
| except ProviderTokenNotInitError: | |||||
| return [] | |||||
| embeddings = CacheEmbedding(embedding_model) | |||||
| documents = [] | |||||
| threads = [] | |||||
| if self.top_k > 0: | |||||
| # semantic search (also used for hybrid search) | |||||
| if retrieval_model['search_method'] == 'semantic_search' or retrieval_model[ | |||||
| 'search_method'] == 'hybrid_search': | |||||
| embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={ | |||||
| 'flask_app': current_app._get_current_object(), | |||||
| 'dataset': dataset, | |||||
| 'query': query, | |||||
| 'top_k': self.top_k, | |||||
| 'score_threshold': self.score_threshold, | |||||
| 'reranking_model': None, | |||||
| 'all_documents': documents, | |||||
| 'search_method': 'hybrid_search', | |||||
| 'embeddings': embeddings | |||||
| }) | |||||
| threads.append(embedding_thread) | |||||
| embedding_thread.start() | |||||
| # full-text search (also used for hybrid search) | |||||
| if retrieval_model['search_method'] == 'full_text_search' or retrieval_model[ | |||||
| 'search_method'] == 'hybrid_search': | |||||
| full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, | |||||
| kwargs={ | |||||
| 'flask_app': current_app._get_current_object(), | |||||
| 'dataset': dataset, | |||||
| 'query': query, | |||||
| 'search_method': 'hybrid_search', | |||||
| 'embeddings': embeddings, | |||||
| 'score_threshold': retrieval_model[ | |||||
| 'score_threshold'] if retrieval_model[ | |||||
| 'score_threshold_enable'] else None, | |||||
| 'top_k': self.top_k, | |||||
| 'reranking_model': retrieval_model[ | |||||
| 'reranking_model'] if retrieval_model[ | |||||
| 'reranking_enable'] else None, | |||||
| 'all_documents': documents | |||||
| }) | |||||
| threads.append(full_text_index_thread) | |||||
| full_text_index_thread.start() | |||||
| for thread in threads: | |||||
| thread.join() | |||||
| all_documents.extend(documents) |
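For orientation, a hedged sketch of how the multi-dataset tool defined above might be built and invoked; DatasetToolCallbackHandler and conversation_message_task are assumed to come from the agent orchestration layer, and all values are placeholders:

    # Sketch only: construct the tool via its factory and run a query.
    multi_tool = DatasetMultiRetrieverTool.from_dataset(
        dataset_ids=['dataset-id-1', 'dataset-id-2'],
        tenant_id='tenant-id',
        top_k=2,
        score_threshold=0.5,
        reranking_provider_name='cohere',            # illustrative reranker
        reranking_model_name='rerank-english-v2.0',
        callbacks=[DatasetToolCallbackHandler(conversation_message_task)],  # assumed import from the agent layer
        conversation_message_task=conversation_message_task,
        return_resource=True,
        retriever_from='dev'
    )
    # _run starts one retrieval thread per dataset, joins them, then reranks the merged results
    answer_context = multi_tool.run('what is the refund policy?')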
| import json | import json | ||||
| from typing import Type, Optional | |||||
| import threading | |||||
| from typing import Type, Optional, List | |||||
| from flask import current_app | from flask import current_app | ||||
| from langchain.tools import BaseTool | from langchain.tools import BaseTool | ||||
| from core.model_providers.model_factory import ModelFactory | from core.model_providers.model_factory import ModelFactory | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from models.dataset import Dataset, DocumentSegment, Document | from models.dataset import Dataset, DocumentSegment, Document | ||||
| from services.retrieval_service import RetrievalService | |||||
| default_retrieval_model = { | |||||
| 'search_method': 'semantic_search', | |||||
| 'reranking_enable': False, | |||||
| 'reranking_model': { | |||||
| 'reranking_provider_name': '', | |||||
| 'reranking_model_name': '' | |||||
| }, | |||||
| 'top_k': 2, | |||||
| 'score_threshold_enable': False | |||||
| } | |||||
| class DatasetRetrieverToolInput(BaseModel): | class DatasetRetrieverToolInput(BaseModel): | ||||
| ).first() | ).first() | ||||
| if not dataset: | if not dataset: | ||||
| return f'[{self.name} failed to find dataset with id {self.dataset_id}.]' | |||||
| return '' | |||||
| # get the dataset's retrieval model; fall back to the default if it is not set | |||||
| retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model | |||||
| if dataset.indexing_technique == "economy": | if dataset.indexing_technique == "economy": | ||||
| # use keyword table query | # use keyword table query | ||||
| return '' | return '' | ||||
| embeddings = CacheEmbedding(embedding_model) | embeddings = CacheEmbedding(embedding_model) | ||||
| vector_index = VectorIndex( | |||||
| dataset=dataset, | |||||
| config=current_app.config, | |||||
| embeddings=embeddings | |||||
| ) | |||||
| documents = [] | |||||
| threads = [] | |||||
| if self.top_k > 0: | if self.top_k > 0: | ||||
| documents = vector_index.search( | |||||
| query, | |||||
| search_type='similarity_score_threshold', | |||||
| search_kwargs={ | |||||
| 'k': self.top_k, | |||||
| 'score_threshold': self.score_threshold, | |||||
| 'filter': { | |||||
| 'group_id': [dataset.id] | |||||
| } | |||||
| } | |||||
| ) | |||||
| # semantic search (also used for hybrid search) | |||||
| if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search': | |||||
| embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={ | |||||
| 'flask_app': current_app._get_current_object(), | |||||
| 'dataset': dataset, | |||||
| 'query': query, | |||||
| 'top_k': self.top_k, | |||||
| 'score_threshold': retrieval_model['score_threshold'] if retrieval_model[ | |||||
| 'score_threshold_enable'] else None, | |||||
| 'reranking_model': retrieval_model['reranking_model'] if retrieval_model[ | |||||
| 'reranking_enable'] else None, | |||||
| 'all_documents': documents, | |||||
| 'search_method': retrieval_model['search_method'], | |||||
| 'embeddings': embeddings | |||||
| }) | |||||
| threads.append(embedding_thread) | |||||
| embedding_thread.start() | |||||
| # full-text search (also used for hybrid search) | |||||
| if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search': | |||||
| full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={ | |||||
| 'flask_app': current_app._get_current_object(), | |||||
| 'dataset': dataset, | |||||
| 'query': query, | |||||
| 'search_method': retrieval_model['search_method'], | |||||
| 'embeddings': embeddings, | |||||
| 'score_threshold': retrieval_model['score_threshold'] if retrieval_model[ | |||||
| 'score_threshold_enable'] else None, | |||||
| 'top_k': self.top_k, | |||||
| 'reranking_model': retrieval_model['reranking_model'] if retrieval_model[ | |||||
| 'reranking_enable'] else None, | |||||
| 'all_documents': documents | |||||
| }) | |||||
| threads.append(full_text_index_thread) | |||||
| full_text_index_thread.start() | |||||
| for thread in threads: | |||||
| thread.join() | |||||
| # hybrid search: rerank after all documents have been searched | |||||
| if retrieval_model['search_method'] == 'hybrid_search': | |||||
| hybrid_rerank = ModelFactory.get_reranking_model( | |||||
| tenant_id=dataset.tenant_id, | |||||
| model_provider_name=retrieval_model['reranking_model']['reranking_provider_name'], | |||||
| model_name=retrieval_model['reranking_model']['reranking_model_name'] | |||||
| ) | |||||
| documents = hybrid_rerank.rerank(query, documents, | |||||
| retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None, | |||||
| self.top_k) | |||||
| else: | else: | ||||
| documents = [] | documents = [] | ||||
| hit_callback = DatasetIndexToolCallbackHandler(dataset.id, self.conversation_message_task) | |||||
| hit_callback = DatasetIndexToolCallbackHandler(self.conversation_message_task) | |||||
| hit_callback.on_tool_end(documents) | hit_callback.on_tool_end(documents) | ||||
| document_score_list = {} | document_score_list = {} | ||||
| if dataset.indexing_technique != "economy": | if dataset.indexing_technique != "economy": |
| from core.index.vector_index.milvus import Milvus | |||||
| from core.vector_store.vector.milvus import Milvus | |||||
| class MilvusVectorStore(Milvus): | class MilvusVectorStore(Milvus): |
| from qdrant_client.http.models import Filter, PointIdsList, FilterSelector | from qdrant_client.http.models import Filter, PointIdsList, FilterSelector | ||||
| from qdrant_client.local.qdrant_local import QdrantLocal | from qdrant_client.local.qdrant_local import QdrantLocal | ||||
| from core.index.vector_index.qdrant import Qdrant | |||||
| from core.vector_store.vector.qdrant import Qdrant | |||||
| class QdrantVectorStore(Qdrant): | class QdrantVectorStore(Qdrant): | ||||
| if isinstance(self.client, QdrantLocal): | if isinstance(self.client, QdrantLocal): | ||||
| self.client = cast(QdrantLocal, self.client) | self.client = cast(QdrantLocal, self.client) | ||||
| self.client._load() | self.client._load() | ||||
| from langchain.embeddings.base import Embeddings | from langchain.embeddings.base import Embeddings | ||||
| from langchain.vectorstores import VectorStore | from langchain.vectorstores import VectorStore | ||||
| from langchain.vectorstores.utils import maximal_marginal_relevance | from langchain.vectorstores.utils import maximal_marginal_relevance | ||||
| from qdrant_client.http.models import PayloadSchemaType | |||||
| from qdrant_client.http.models import PayloadSchemaType, FilterSelector, TextIndexParams, TokenizerType, TextIndexType | |||||
| if TYPE_CHECKING: | if TYPE_CHECKING: | ||||
| from qdrant_client import grpc # noqa | from qdrant_client import grpc # noqa | ||||
| texts, metadatas, ids, batch_size | texts, metadatas, ids, batch_size | ||||
| ): | ): | ||||
| self.client.upsert( | self.client.upsert( | ||||
| collection_name=self.collection_name, points=points, **kwargs | |||||
| collection_name=self.collection_name, points=points | |||||
| ) | ) | ||||
| added_ids.extend(batch_ids) | added_ids.extend(batch_ids) | ||||
| # if this is a new collection, create a payload index on group_id | # if this is a new collection, create a payload index on group_id | ||||
| if self.is_new_collection: | if self.is_new_collection: | ||||
| # create payload index | |||||
| self.client.create_payload_index(self.collection_name, self.group_payload_key, | self.client.create_payload_index(self.collection_name, self.group_payload_key, | ||||
| field_schema=PayloadSchemaType.KEYWORD, | field_schema=PayloadSchemaType.KEYWORD, | ||||
| field_type=PayloadSchemaType.KEYWORD) | field_type=PayloadSchemaType.KEYWORD) | ||||
| # create full text index | |||||
| text_index_params = TextIndexParams( | |||||
| type=TextIndexType.TEXT, | |||||
| tokenizer=TokenizerType.MULTILINGUAL, | |||||
| min_token_len=2, | |||||
| max_token_len=20, | |||||
| lowercase=True | |||||
| ) | |||||
| self.client.create_payload_index(self.collection_name, self.content_payload_key, | |||||
| field_schema=text_index_params) | |||||
| return added_ids | return added_ids | ||||
| @sync_call_fallback | @sync_call_fallback | ||||
| limit=k, | limit=k, | ||||
| offset=offset, | offset=offset, | ||||
| with_payload=True, | with_payload=True, | ||||
| with_vectors=True, # Langchain does not expect vectors to be returned | |||||
| with_vectors=True, | |||||
| score_threshold=score_threshold, | score_threshold=score_threshold, | ||||
| consistency=consistency, | consistency=consistency, | ||||
| **kwargs, | **kwargs, | ||||
| for result in results | for result in results | ||||
| ] | ] | ||||
| def similarity_search_by_bm25( | |||||
| self, | |||||
| filter: Optional[MetadataFilter] = None, | |||||
| k: int = 4 | |||||
| ) -> List[Document]: | |||||
| """Return docs most similar by bm25. | |||||
| Args: | |||||
| embedding: Embedding vector to look up documents similar to. | |||||
| k: Number of Documents to return. Defaults to 4. | |||||
| filter: Filter by metadata. Defaults to None. | |||||
| search_params: Additional search params | |||||
| Returns: | |||||
| List of documents most similar to the query text and distance for each. | |||||
| """ | |||||
| response = self.client.scroll( | |||||
| collection_name=self.collection_name, | |||||
| scroll_filter=filter, | |||||
| limit=k, | |||||
| with_payload=True, | |||||
| with_vectors=True | |||||
| ) | |||||
| results = response[0] | |||||
| documents = [] | |||||
| for result in results: | |||||
| if result: | |||||
| documents.append(self._document_from_scored_point( | |||||
| result, self.content_payload_key, self.metadata_payload_key | |||||
| )) | |||||
| return documents | |||||
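A hedged usage sketch for the scroll-based lookup above, assuming the full-text payload index created on 'page_content' elsewhere in this change and a qdrant-client version that provides MatchText; qdrant_vector_store is an assumed instance of the wrapper class:

    # Sketch only: build a full-text match filter and scroll for matching points.
    from qdrant_client.http import models as rest

    bm25_filter = rest.Filter(
        must=[rest.FieldCondition(key='page_content', match=rest.MatchText(text='refund policy'))]
    )
    docs = qdrant_vector_store.similarity_search_by_bm25(filter=bm25_filter, k=4)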
| @sync_call_fallback | @sync_call_fallback | ||||
| async def asimilarity_search_with_score_by_vector( | async def asimilarity_search_with_score_by_vector( | ||||
| self, | self, |
| """Wrapper around weaviate vector database.""" | |||||
| from __future__ import annotations | |||||
| import datetime | |||||
| from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type | |||||
| from uuid import uuid4 | |||||
| import numpy as np | |||||
| from langchain.docstore.document import Document | |||||
| from langchain.embeddings.base import Embeddings | |||||
| from langchain.utils import get_from_dict_or_env | |||||
| from langchain.vectorstores.base import VectorStore | |||||
| from langchain.vectorstores.utils import maximal_marginal_relevance | |||||
| def _default_schema(index_name: str) -> Dict: | |||||
| return { | |||||
| "class": index_name, | |||||
| "properties": [ | |||||
| { | |||||
| "name": "text", | |||||
| "dataType": ["text"], | |||||
| } | |||||
| ], | |||||
| } | |||||
| def _create_weaviate_client(**kwargs: Any) -> Any: | |||||
| client = kwargs.get("client") | |||||
| if client is not None: | |||||
| return client | |||||
| weaviate_url = get_from_dict_or_env(kwargs, "weaviate_url", "WEAVIATE_URL") | |||||
| try: | |||||
| # the weaviate api key param should not be mandatory | |||||
| weaviate_api_key = get_from_dict_or_env( | |||||
| kwargs, "weaviate_api_key", "WEAVIATE_API_KEY", None | |||||
| ) | |||||
| except ValueError: | |||||
| weaviate_api_key = None | |||||
| try: | |||||
| import weaviate | |||||
| except ImportError: | |||||
| raise ValueError( | |||||
| "Could not import weaviate python package. " | |||||
| "Please install it with `pip install weaviate-client`" | |||||
| ) | |||||
| auth = ( | |||||
| weaviate.auth.AuthApiKey(api_key=weaviate_api_key) | |||||
| if weaviate_api_key is not None | |||||
| else None | |||||
| ) | |||||
| client = weaviate.Client(weaviate_url, auth_client_secret=auth) | |||||
| return client | |||||
| def _default_score_normalizer(val: float) -> float: | |||||
| return 1 - 1 / (1 + np.exp(val)) | |||||
| def _json_serializable(value: Any) -> Any: | |||||
| if isinstance(value, datetime.datetime): | |||||
| return value.isoformat() | |||||
| return value | |||||
| class Weaviate(VectorStore): | |||||
| """Wrapper around Weaviate vector database. | |||||
| To use, you should have the ``weaviate-client`` python package installed. | |||||
| Example: | |||||
| .. code-block:: python | |||||
| import weaviate | |||||
| from langchain.vectorstores import Weaviate | |||||
| client = weaviate.Client(url=os.environ["WEAVIATE_URL"], ...) | |||||
| weaviate = Weaviate(client, index_name, text_key) | |||||
| """ | |||||
| def __init__( | |||||
| self, | |||||
| client: Any, | |||||
| index_name: str, | |||||
| text_key: str, | |||||
| embedding: Optional[Embeddings] = None, | |||||
| attributes: Optional[List[str]] = None, | |||||
| relevance_score_fn: Optional[ | |||||
| Callable[[float], float] | |||||
| ] = _default_score_normalizer, | |||||
| by_text: bool = True, | |||||
| ): | |||||
| """Initialize with Weaviate client.""" | |||||
| try: | |||||
| import weaviate | |||||
| except ImportError: | |||||
| raise ValueError( | |||||
| "Could not import weaviate python package. " | |||||
| "Please install it with `pip install weaviate-client`." | |||||
| ) | |||||
| if not isinstance(client, weaviate.Client): | |||||
| raise ValueError( | |||||
| f"client should be an instance of weaviate.Client, got {type(client)}" | |||||
| ) | |||||
| self._client = client | |||||
| self._index_name = index_name | |||||
| self._embedding = embedding | |||||
| self._text_key = text_key | |||||
| self._query_attrs = [self._text_key] | |||||
| self.relevance_score_fn = relevance_score_fn | |||||
| self._by_text = by_text | |||||
| if attributes is not None: | |||||
| self._query_attrs.extend(attributes) | |||||
| @property | |||||
| def embeddings(self) -> Optional[Embeddings]: | |||||
| return self._embedding | |||||
| def _select_relevance_score_fn(self) -> Callable[[float], float]: | |||||
| return ( | |||||
| self.relevance_score_fn | |||||
| if self.relevance_score_fn | |||||
| else _default_score_normalizer | |||||
| ) | |||||
| def add_texts( | |||||
| self, | |||||
| texts: Iterable[str], | |||||
| metadatas: Optional[List[dict]] = None, | |||||
| **kwargs: Any, | |||||
| ) -> List[str]: | |||||
| """Upload texts with metadata (properties) to Weaviate.""" | |||||
| from weaviate.util import get_valid_uuid | |||||
| ids = [] | |||||
| embeddings: Optional[List[List[float]]] = None | |||||
| if self._embedding: | |||||
| if not isinstance(texts, list): | |||||
| texts = list(texts) | |||||
| embeddings = self._embedding.embed_documents(texts) | |||||
| with self._client.batch as batch: | |||||
| for i, text in enumerate(texts): | |||||
| data_properties = {self._text_key: text} | |||||
| if metadatas is not None: | |||||
| for key, val in metadatas[i].items(): | |||||
| data_properties[key] = _json_serializable(val) | |||||
| # Allow for ids (consistent w/ other methods) | |||||
| # Or uuids (backwards compatible w/ existing arg) | |||||
| # If the UUID of one of the objects already exists | |||||
| # then the existing object will be replaced by the new object. | |||||
| _id = get_valid_uuid(uuid4()) | |||||
| if "uuids" in kwargs: | |||||
| _id = kwargs["uuids"][i] | |||||
| elif "ids" in kwargs: | |||||
| _id = kwargs["ids"][i] | |||||
| batch.add_data_object( | |||||
| data_object=data_properties, | |||||
| class_name=self._index_name, | |||||
| uuid=_id, | |||||
| vector=embeddings[i] if embeddings else None, | |||||
| ) | |||||
| ids.append(_id) | |||||
| return ids | |||||
| def similarity_search( | |||||
| self, query: str, k: int = 4, **kwargs: Any | |||||
| ) -> List[Document]: | |||||
| """Return docs most similar to query. | |||||
| Args: | |||||
| query: Text to look up documents similar to. | |||||
| k: Number of Documents to return. Defaults to 4. | |||||
| Returns: | |||||
| List of Documents most similar to the query. | |||||
| """ | |||||
| if self._by_text: | |||||
| return self.similarity_search_by_text(query, k, **kwargs) | |||||
| else: | |||||
| if self._embedding is None: | |||||
| raise ValueError( | |||||
| "_embedding cannot be None for similarity_search when " | |||||
| "_by_text=False" | |||||
| ) | |||||
| embedding = self._embedding.embed_query(query) | |||||
| return self.similarity_search_by_vector(embedding, k, **kwargs) | |||||
| def similarity_search_by_text( | |||||
| self, query: str, k: int = 4, **kwargs: Any | |||||
| ) -> List[Document]: | |||||
| """Return docs most similar to query. | |||||
| Args: | |||||
| query: Text to look up documents similar to. | |||||
| k: Number of Documents to return. Defaults to 4. | |||||
| Returns: | |||||
| List of Documents most similar to the query. | |||||
| """ | |||||
| content: Dict[str, Any] = {"concepts": [query]} | |||||
| if kwargs.get("search_distance"): | |||||
| content["certainty"] = kwargs.get("search_distance") | |||||
| query_obj = self._client.query.get(self._index_name, self._query_attrs) | |||||
| if kwargs.get("where_filter"): | |||||
| query_obj = query_obj.with_where(kwargs.get("where_filter")) | |||||
| if kwargs.get("additional"): | |||||
| query_obj = query_obj.with_additional(kwargs.get("additional")) | |||||
| result = query_obj.with_near_text(content).with_limit(k).do() | |||||
| if "errors" in result: | |||||
| raise ValueError(f"Error during query: {result['errors']}") | |||||
| docs = [] | |||||
| for res in result["data"]["Get"][self._index_name]: | |||||
| text = res.pop(self._text_key) | |||||
| docs.append(Document(page_content=text, metadata=res)) | |||||
| return docs | |||||
| def similarity_search_by_bm25( | |||||
| self, query: str, k: int = 4, **kwargs: Any | |||||
| ) -> List[Document]: | |||||
| """Return docs using BM25F. | |||||
| Args: | |||||
| query: Text to look up documents similar to. | |||||
| k: Number of Documents to return. Defaults to 4. | |||||
| Returns: | |||||
| List of Documents most similar to the query. | |||||
| """ | |||||
| content: Dict[str, Any] = {"concepts": [query]} | |||||
| if kwargs.get("search_distance"): | |||||
| content["certainty"] = kwargs.get("search_distance") | |||||
| query_obj = self._client.query.get(self._index_name, self._query_attrs) | |||||
| if kwargs.get("where_filter"): | |||||
| query_obj = query_obj.with_where(kwargs.get("where_filter")) | |||||
| if kwargs.get("additional"): | |||||
| query_obj = query_obj.with_additional(kwargs.get("additional")) | |||||
| result = query_obj.with_bm25(query=query).with_limit(k).do() | |||||
| if "errors" in result: | |||||
| raise ValueError(f"Error during query: {result['errors']}") | |||||
| docs = [] | |||||
| for res in result["data"]["Get"][self._index_name]: | |||||
| text = res.pop(self._text_key) | |||||
| docs.append(Document(page_content=text, metadata=res)) | |||||
| return docs | |||||
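A similarly hedged call for the Weaviate BM25F path above; the where_filter uses Weaviate's standard where-clause shape, the 'group_id' property is an assumption about how objects are tagged, and weaviate_store is an assumed wrapper instance:

    # Sketch only: keyword (BM25F) search scoped to one dataset's objects.
    docs = weaviate_store.similarity_search_by_bm25(
        'refund policy',
        k=4,
        where_filter={'path': ['group_id'], 'operator': 'Equal', 'valueText': 'dataset-id-1'}
    )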
| def similarity_search_by_vector( | |||||
| self, embedding: List[float], k: int = 4, **kwargs: Any | |||||
| ) -> List[Document]: | |||||
| """Look up similar documents by embedding vector in Weaviate.""" | |||||
| vector = {"vector": embedding} | |||||
| query_obj = self._client.query.get(self._index_name, self._query_attrs) | |||||
| if kwargs.get("where_filter"): | |||||
| query_obj = query_obj.with_where(kwargs.get("where_filter")) | |||||
| if kwargs.get("additional"): | |||||
| query_obj = query_obj.with_additional(kwargs.get("additional")) | |||||
| result = query_obj.with_near_vector(vector).with_limit(k).do() | |||||
| if "errors" in result: | |||||
| raise ValueError(f"Error during query: {result['errors']}") | |||||
| docs = [] | |||||
| for res in result["data"]["Get"][self._index_name]: | |||||
| text = res.pop(self._text_key) | |||||
| docs.append(Document(page_content=text, metadata=res)) | |||||
| return docs | |||||
| def max_marginal_relevance_search( | |||||
| self, | |||||
| query: str, | |||||
| k: int = 4, | |||||
| fetch_k: int = 20, | |||||
| lambda_mult: float = 0.5, | |||||
| **kwargs: Any, | |||||
| ) -> List[Document]: | |||||
| """Return docs selected using the maximal marginal relevance. | |||||
| Maximal marginal relevance optimizes for similarity to query AND diversity | |||||
| among selected documents. | |||||
| Args: | |||||
| query: Text to look up documents similar to. | |||||
| k: Number of Documents to return. Defaults to 4. | |||||
| fetch_k: Number of Documents to fetch to pass to MMR algorithm. | |||||
| lambda_mult: Number between 0 and 1 that determines the degree | |||||
| of diversity among the results with 0 corresponding | |||||
| to maximum diversity and 1 to minimum diversity. | |||||
| Defaults to 0.5. | |||||
| Returns: | |||||
| List of Documents selected by maximal marginal relevance. | |||||
| """ | |||||
| if self._embedding is not None: | |||||
| embedding = self._embedding.embed_query(query) | |||||
| else: | |||||
| raise ValueError( | |||||
| "max_marginal_relevance_search requires a suitable Embeddings object" | |||||
| ) | |||||
| return self.max_marginal_relevance_search_by_vector( | |||||
| embedding, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, **kwargs | |||||
| ) | |||||
| def max_marginal_relevance_search_by_vector( | |||||
| self, | |||||
| embedding: List[float], | |||||
| k: int = 4, | |||||
| fetch_k: int = 20, | |||||
| lambda_mult: float = 0.5, | |||||
| **kwargs: Any, | |||||
| ) -> List[Document]: | |||||
| """Return docs selected using the maximal marginal relevance. | |||||
| Maximal marginal relevance optimizes for similarity to query AND diversity | |||||
| among selected documents. | |||||
| Args: | |||||
| embedding: Embedding to look up documents similar to. | |||||
| k: Number of Documents to return. Defaults to 4. | |||||
| fetch_k: Number of Documents to fetch to pass to MMR algorithm. | |||||
| lambda_mult: Number between 0 and 1 that determines the degree | |||||
| of diversity among the results with 0 corresponding | |||||
| to maximum diversity and 1 to minimum diversity. | |||||
| Defaults to 0.5. | |||||
| Returns: | |||||
| List of Documents selected by maximal marginal relevance. | |||||
| """ | |||||
| vector = {"vector": embedding} | |||||
| query_obj = self._client.query.get(self._index_name, self._query_attrs) | |||||
| if kwargs.get("where_filter"): | |||||
| query_obj = query_obj.with_where(kwargs.get("where_filter")) | |||||
| results = ( | |||||
| query_obj.with_additional("vector") | |||||
| .with_near_vector(vector) | |||||
| .with_limit(fetch_k) | |||||
| .do() | |||||
| ) | |||||
| payload = results["data"]["Get"][self._index_name] | |||||
| embeddings = [result["_additional"]["vector"] for result in payload] | |||||
| mmr_selected = maximal_marginal_relevance( | |||||
| np.array(embedding), embeddings, k=k, lambda_mult=lambda_mult | |||||
| ) | |||||
| docs = [] | |||||
| for idx in mmr_selected: | |||||
| text = payload[idx].pop(self._text_key) | |||||
| payload[idx].pop("_additional") | |||||
| meta = payload[idx] | |||||
| docs.append(Document(page_content=text, metadata=meta)) | |||||
| return docs | |||||
| def similarity_search_with_score( | |||||
| self, query: str, k: int = 4, **kwargs: Any | |||||
| ) -> List[Tuple[Document, float]]: | |||||
| """ | |||||
| Return list of documents most similar to the query | |||||
| text and cosine distance in float for each. | |||||
| Lower score represents more similarity. | |||||
| """ | |||||
| if self._embedding is None: | |||||
| raise ValueError( | |||||
| "_embedding cannot be None for similarity_search_with_score" | |||||
| ) | |||||
| content: Dict[str, Any] = {"concepts": [query]} | |||||
| if kwargs.get("search_distance"): | |||||
| content["certainty"] = kwargs.get("search_distance") | |||||
| query_obj = self._client.query.get(self._index_name, self._query_attrs) | |||||
| embedded_query = self._embedding.embed_query(query) | |||||
| if not self._by_text: | |||||
| vector = {"vector": embedded_query} | |||||
| result = ( | |||||
| query_obj.with_near_vector(vector) | |||||
| .with_limit(k) | |||||
| .with_additional("vector") | |||||
| .do() | |||||
| ) | |||||
| else: | |||||
| result = ( | |||||
| query_obj.with_near_text(content) | |||||
| .with_limit(k) | |||||
| .with_additional("vector") | |||||
| .do() | |||||
| ) | |||||
| if "errors" in result: | |||||
| raise ValueError(f"Error during query: {result['errors']}") | |||||
| docs_and_scores = [] | |||||
| for res in result["data"]["Get"][self._index_name]: | |||||
| text = res.pop(self._text_key) | |||||
| score = np.dot(res["_additional"]["vector"], embedded_query) | |||||
| docs_and_scores.append((Document(page_content=text, metadata=res), score)) | |||||
| return docs_and_scores | |||||
| @classmethod | |||||
| def from_texts( | |||||
| cls: Type[Weaviate], | |||||
| texts: List[str], | |||||
| embedding: Embeddings, | |||||
| metadatas: Optional[List[dict]] = None, | |||||
| **kwargs: Any, | |||||
| ) -> Weaviate: | |||||
| """Construct Weaviate wrapper from raw documents. | |||||
| This is a user-friendly interface that: | |||||
| 1. Embeds documents. | |||||
| 2. Creates a new index for the embeddings in the Weaviate instance. | |||||
| 3. Adds the documents to the newly created Weaviate index. | |||||
| This is intended to be a quick way to get started. | |||||
| Example: | |||||
| .. code-block:: python | |||||
| from langchain.vectorstores.weaviate import Weaviate | |||||
| from langchain.embeddings import OpenAIEmbeddings | |||||
| embeddings = OpenAIEmbeddings() | |||||
| weaviate = Weaviate.from_texts( | |||||
| texts, | |||||
| embeddings, | |||||
| weaviate_url="http://localhost:8080" | |||||
| ) | |||||
| """ | |||||
| client = _create_weaviate_client(**kwargs) | |||||
| from weaviate.util import get_valid_uuid | |||||
| index_name = kwargs.get("index_name", f"LangChain_{uuid4().hex}") | |||||
| embeddings = embedding.embed_documents(texts) if embedding else None | |||||
| text_key = "text" | |||||
| schema = _default_schema(index_name) | |||||
| attributes = list(metadatas[0].keys()) if metadatas else None | |||||
| # check whether the index already exists | |||||
| if not client.schema.contains(schema): | |||||
| client.schema.create_class(schema) | |||||
| with client.batch as batch: | |||||
| for i, text in enumerate(texts): | |||||
| data_properties = { | |||||
| text_key: text, | |||||
| } | |||||
| if metadatas is not None: | |||||
| for key in metadatas[i].keys(): | |||||
| data_properties[key] = metadatas[i][key] | |||||
| # If the UUID of one of the objects already exists | |||||
| # then the existing object will be replaced by the new object. | |||||
| if "uuids" in kwargs: | |||||
| _id = kwargs["uuids"][i] | |||||
| else: | |||||
| _id = get_valid_uuid(uuid4()) | |||||
| # if an embedding strategy is not provided, we let | |||||
| # weaviate create the embedding. Note that this will only | |||||
| # work if weaviate has been installed with a vectorizer module | |||||
| # like text2vec-contextionary for example | |||||
| params = { | |||||
| "uuid": _id, | |||||
| "data_object": data_properties, | |||||
| "class_name": index_name, | |||||
| } | |||||
| if embeddings is not None: | |||||
| params["vector"] = embeddings[i] | |||||
| batch.add_data_object(**params) | |||||
| batch.flush() | |||||
| relevance_score_fn = kwargs.get("relevance_score_fn") | |||||
| by_text: bool = kwargs.get("by_text", False) | |||||
| return cls( | |||||
| client, | |||||
| index_name, | |||||
| text_key, | |||||
| embedding=embedding, | |||||
| attributes=attributes, | |||||
| relevance_score_fn=relevance_score_fn, | |||||
| by_text=by_text, | |||||
| ) | |||||
| def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> None: | |||||
| """Delete by vector IDs. | |||||
| Args: | |||||
| ids: List of ids to delete. | |||||
| """ | |||||
| if ids is None: | |||||
| raise ValueError("No ids provided to delete.") | |||||
| # TODO: Check if this can be done in bulk | |||||
| for id in ids: | |||||
| self._client.data_object.delete(uuid=id) |
| 'created_at': TimestampField, | 'created_at': TimestampField, | ||||
| } | } | ||||
| reranking_model_fields = { | |||||
| 'reranking_provider_name': fields.String, | |||||
| 'reranking_model_name': fields.String | |||||
| } | |||||
| dataset_retrieval_model_fields = { | |||||
| 'search_method': fields.String, | |||||
| 'reranking_enable': fields.Boolean, | |||||
| 'reranking_model': fields.Nested(reranking_model_fields), | |||||
| 'top_k': fields.Integer, | |||||
| 'score_threshold_enable': fields.Boolean, | |||||
| 'score_threshold': fields.Float | |||||
| } | |||||
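For API consumers, a hedged example of the shape these fields serialise a dataset's retrieval settings into (all values illustrative):

    # Illustrative output shape for 'retrieval_model_dict' under dataset_retrieval_model_fields
    example_retrieval_model = {
        'search_method': 'hybrid_search',
        'reranking_enable': True,
        'reranking_model': {
            'reranking_provider_name': 'cohere',
            'reranking_model_name': 'rerank-english-v2.0'
        },
        'top_k': 2,
        'score_threshold_enable': True,
        'score_threshold': 0.5
    }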
| dataset_detail_fields = { | dataset_detail_fields = { | ||||
| 'id': fields.String, | 'id': fields.String, | ||||
| 'name': fields.String, | 'name': fields.String, | ||||
| 'updated_at': TimestampField, | 'updated_at': TimestampField, | ||||
| 'embedding_model': fields.String, | 'embedding_model': fields.String, | ||||
| 'embedding_model_provider': fields.String, | 'embedding_model_provider': fields.String, | ||||
| 'embedding_available': fields.Boolean | |||||
| 'embedding_available': fields.Boolean, | |||||
| 'retrieval_model_dict': fields.Nested(dataset_retrieval_model_fields) | |||||
| } | } | ||||
| dataset_query_detail_fields = { | dataset_query_detail_fields = { | ||||
| "created_by": fields.String, | "created_by": fields.String, | ||||
| "created_at": TimestampField | "created_at": TimestampField | ||||
| } | } | ||||
| """add-dataset-retrival-model | |||||
| Revision ID: fca025d3b60f | |||||
| Revises: 8fe468ba0ca5 | |||||
| Create Date: 2023-11-03 13:08:23.246396 | |||||
| """ | |||||
| from alembic import op | |||||
| import sqlalchemy as sa | |||||
| from sqlalchemy.dialects import postgresql | |||||
| # revision identifiers, used by Alembic. | |||||
| revision = 'fca025d3b60f' | |||||
| down_revision = '8fe468ba0ca5' | |||||
| branch_labels = None | |||||
| depends_on = None | |||||
| def upgrade(): | |||||
| # ### commands auto generated by Alembic - please adjust! ### | |||||
| op.drop_table('sessions') | |||||
| with op.batch_alter_table('datasets', schema=None) as batch_op: | |||||
| batch_op.add_column(sa.Column('retrieval_model', postgresql.JSONB(astext_type=sa.Text()), nullable=True)) | |||||
| batch_op.create_index('retrieval_model_idx', ['retrieval_model'], unique=False, postgresql_using='gin') | |||||
| # ### end Alembic commands ### | |||||
| def downgrade(): | |||||
| # ### commands auto generated by Alembic - please adjust! ### | |||||
| with op.batch_alter_table('datasets', schema=None) as batch_op: | |||||
| batch_op.drop_index('retrieval_model_idx', postgresql_using='gin') | |||||
| batch_op.drop_column('retrieval_model') | |||||
| op.create_table('sessions', | |||||
| sa.Column('id', sa.INTEGER(), autoincrement=True, nullable=False), | |||||
| sa.Column('session_id', sa.VARCHAR(length=255), autoincrement=False, nullable=True), | |||||
| sa.Column('data', postgresql.BYTEA(), autoincrement=False, nullable=True), | |||||
| sa.Column('expiry', postgresql.TIMESTAMP(), autoincrement=False, nullable=True), | |||||
| sa.PrimaryKeyConstraint('id', name='sessions_pkey'), | |||||
| sa.UniqueConstraint('session_id', name='sessions_session_id_key') | |||||
| ) | |||||
| # ### end Alembic commands ### |
| from json import JSONDecodeError | from json import JSONDecodeError | ||||
| from sqlalchemy import func | from sqlalchemy import func | ||||
| from sqlalchemy.dialects.postgresql import UUID | |||||
| from sqlalchemy.dialects.postgresql import UUID, JSONB | |||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from models.account import Account | from models.account import Account | ||||
| __table_args__ = ( | __table_args__ = ( | ||||
| db.PrimaryKeyConstraint('id', name='dataset_pkey'), | db.PrimaryKeyConstraint('id', name='dataset_pkey'), | ||||
| db.Index('dataset_tenant_idx', 'tenant_id'), | db.Index('dataset_tenant_idx', 'tenant_id'), | ||||
| db.Index('retrieval_model_idx', "retrieval_model", postgresql_using='gin') | |||||
| ) | ) | ||||
| INDEXING_TECHNIQUE_LIST = ['high_quality', 'economy'] | INDEXING_TECHNIQUE_LIST = ['high_quality', 'economy'] | ||||
| embedding_model = db.Column(db.String(255), nullable=True) | embedding_model = db.Column(db.String(255), nullable=True) | ||||
| embedding_model_provider = db.Column(db.String(255), nullable=True) | embedding_model_provider = db.Column(db.String(255), nullable=True) | ||||
| collection_binding_id = db.Column(UUID, nullable=True) | collection_binding_id = db.Column(UUID, nullable=True) | ||||
| retrieval_model = db.Column(JSONB, nullable=True) | |||||
| @property | @property | ||||
| def dataset_keyword_table(self): | def dataset_keyword_table(self): | ||||
| return Document.query.with_entities(func.coalesce(func.sum(Document.word_count))) \ | return Document.query.with_entities(func.coalesce(func.sum(Document.word_count))) \ | ||||
| .filter(Document.dataset_id == self.id).scalar() | .filter(Document.dataset_id == self.id).scalar() | ||||
| @property | |||||
| def retrieval_model_dict(self): | |||||
| default_retrieval_model = { | |||||
| 'search_method': 'semantic_search', | |||||
| 'reranking_enable': False, | |||||
| 'reranking_model': { | |||||
| 'reranking_provider_name': '', | |||||
| 'reranking_model_name': '' | |||||
| }, | |||||
| 'top_k': 2, | |||||
| 'score_threshold_enable': False | |||||
| } | |||||
| return self.retrieval_model if self.retrieval_model else default_retrieval_model | |||||
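A hedged sketch of consuming the new column and property; 'dataset' is an assumed loaded instance, and the containment filter is the kind of query the GIN index from the migration above accelerates:

    # Sketch only: read the effective settings (falls back to the default when the column is NULL)
    method = dataset.retrieval_model_dict['search_method']

    # JSONB containment query (@>), which the 'retrieval_model_idx' GIN index can serve
    hybrid_datasets = db.session.query(Dataset).filter(
        Dataset.retrieval_model.contains({'search_method': 'hybrid_search'})
    ).all()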
| class DatasetProcessRule(db.Model): | class DatasetProcessRule(db.Model): | ||||
| __tablename__ = 'dataset_process_rules' | __tablename__ = 'dataset_process_rules' | ||||
| ], | ], | ||||
| 'segmentation': { | 'segmentation': { | ||||
| 'delimiter': '\n', | 'delimiter': '\n', | ||||
| 'max_tokens': 1000 | |||||
| 'max_tokens': 512 | |||||
| } | } | ||||
| } | } | ||||
| model_name = db.Column(db.String(40), nullable=False) | model_name = db.Column(db.String(40), nullable=False) | ||||
| collection_name = db.Column(db.String(64), nullable=False) | collection_name = db.Column(db.String(64), nullable=False) | ||||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | ||||
| @property | @property | ||||
| def dataset_configs_dict(self) -> dict: | def dataset_configs_dict(self) -> dict: | ||||
| return json.loads(self.dataset_configs) if self.dataset_configs else {"top_k": 2, "score_threshold": {"enable": False}} | |||||
| if self.dataset_configs: | |||||
| dataset_configs = json.loads(self.dataset_configs) | |||||
| if 'retrieval_model' not in dataset_configs: | |||||
| return {'retrieval_model': 'single'} | |||||
| else: | |||||
| return dataset_configs | |||||
| return {'retrieval_model': 'single'} | |||||
| @property | @property | ||||
| def file_upload_dict(self) -> dict: | def file_upload_dict(self) -> dict: |
| tenacity==8.2.2 | tenacity==8.2.2 | ||||
| cachetools~=5.3.0 | cachetools~=5.3.0 | ||||
| weaviate-client~=3.21.0 | weaviate-client~=3.21.0 | ||||
| qdrant_client~=1.1.6 | |||||
| mailchimp-transactional~=1.0.50 | mailchimp-transactional~=1.0.50 | ||||
| scikit-learn==1.2.2 | scikit-learn==1.2.2 | ||||
| sentry-sdk[flask]~=1.21.1 | sentry-sdk[flask]~=1.21.1 | ||||
| safetensors==0.3.2 | safetensors==0.3.2 | ||||
| zhipuai==1.0.7 | zhipuai==1.0.7 | ||||
| werkzeug==2.3.7 | werkzeug==2.3.7 | ||||
| pymilvus==2.3.0 | |||||
| qdrant-client==1.6.4 | |||||
| cohere~=4.32 |
| # dataset_configs | # dataset_configs | ||||
| if 'dataset_configs' not in config or not config["dataset_configs"]: | if 'dataset_configs' not in config or not config["dataset_configs"]: | ||||
| config["dataset_configs"] = {"top_k": 2, "score_threshold": {"enable": False}} | |||||
| config["dataset_configs"] = {'retrieval_model': 'single'} | |||||
| if not isinstance(config["dataset_configs"], dict): | |||||
| raise ValueError("dataset_configs must be of object type") | |||||
| if config["dataset_configs"]['retrieval_model'] == 'multiple': | |||||
| if not config["dataset_configs"]['reranking_model']: | |||||
| raise ValueError("reranking_model has not been set") | |||||
| if not isinstance(config["dataset_configs"]['reranking_model'], dict): | |||||
| raise ValueError("reranking_model must be of object type") | |||||
| if not isinstance(config["dataset_configs"], dict): | if not isinstance(config["dataset_configs"], dict): | ||||
| raise ValueError("dataset_configs must be of object type") | raise ValueError("dataset_configs must be of object type") |
| filtered_data['updated_by'] = user.id | filtered_data['updated_by'] = user.id | ||||
| filtered_data['updated_at'] = datetime.datetime.now() | filtered_data['updated_at'] = datetime.datetime.now() | ||||
| # update Retrieval model | |||||
| filtered_data['retrieval_model'] = data['retrieval_model'] | |||||
| dataset.query.filter_by(id=dataset_id).update(filtered_data) | dataset.query.filter_by(id=dataset_id).update(filtered_data) | ||||
| db.session.commit() | db.session.commit() | ||||
| embedding_model.name | embedding_model.name | ||||
| ) | ) | ||||
| dataset.collection_binding_id = dataset_collection_binding.id | dataset.collection_binding_id = dataset_collection_binding.id | ||||
| if not dataset.retrieval_model: | |||||
| default_retrieval_model = { | |||||
| 'search_method': 'semantic_search', | |||||
| 'reranking_enable': False, | |||||
| 'reranking_model': { | |||||
| 'reranking_provider_name': '', | |||||
| 'reranking_model_name': '' | |||||
| }, | |||||
| 'top_k': 2, | |||||
| 'score_threshold_enable': False | |||||
| } | |||||
| dataset.retrieval_model = document_data.get('retrieval_model') if document_data.get('retrieval_model') else default_retrieval_model | |||||
| documents = [] | documents = [] | ||||
| batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999)) | batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999)) | ||||
| raise ValueError(f"All your documents have overed limit {tenant_document_count}.") | raise ValueError(f"All your documents have overed limit {tenant_document_count}.") | ||||
| embedding_model = None | embedding_model = None | ||||
| dataset_collection_binding_id = None | dataset_collection_binding_id = None | ||||
| retrieval_model = None | |||||
| if document_data['indexing_technique'] == 'high_quality': | if document_data['indexing_technique'] == 'high_quality': | ||||
| embedding_model = ModelFactory.get_embedding_model( | embedding_model = ModelFactory.get_embedding_model( | ||||
| tenant_id=tenant_id | tenant_id=tenant_id | ||||
| embedding_model.name | embedding_model.name | ||||
| ) | ) | ||||
| dataset_collection_binding_id = dataset_collection_binding.id | dataset_collection_binding_id = dataset_collection_binding.id | ||||
| if 'retrieval_model' in document_data and document_data['retrieval_model']: | |||||
| retrieval_model = document_data['retrieval_model'] | |||||
| else: | |||||
| default_retrieval_model = { | |||||
| 'search_method': 'semantic_search', | |||||
| 'reranking_enable': False, | |||||
| 'reranking_model': { | |||||
| 'reranking_provider_name': '', | |||||
| 'reranking_model_name': '' | |||||
| }, | |||||
| 'top_k': 2, | |||||
| 'score_threshold_enable': False | |||||
| } | |||||
| retrieval_model = default_retrieval_model | |||||
| # save dataset | # save dataset | ||||
| dataset = Dataset( | dataset = Dataset( | ||||
| tenant_id=tenant_id, | tenant_id=tenant_id, | ||||
| created_by=account.id, | created_by=account.id, | ||||
| embedding_model=embedding_model.name if embedding_model else None, | embedding_model=embedding_model.name if embedding_model else None, | ||||
| embedding_model_provider=embedding_model.model_provider.provider_name if embedding_model else None, | embedding_model_provider=embedding_model.model_provider.provider_name if embedding_model else None, | ||||
| collection_binding_id=dataset_collection_binding_id | |||||
| collection_binding_id=dataset_collection_binding_id, | |||||
| retrieval_model=retrieval_model | |||||
| ) | ) | ||||
| db.session.add(dataset) | db.session.add(dataset) | ||||
| return dataset, documents, batch | return dataset, documents, batch | ||||
| @classmethod | @classmethod | ||||
| def document_create_args_validate(cls, args: dict): | |||||
| if 'original_document_id' not in args or not args['original_document_id']: | if 'original_document_id' not in args or not args['original_document_id']: | ||||
| DocumentService.data_source_args_validate(args) | DocumentService.data_source_args_validate(args) | ||||
| DocumentService.process_rule_args_validate(args) | DocumentService.process_rule_args_validate(args) |
| import json | |||||
| import logging | import logging | ||||
| import threading | |||||
| import time | import time | ||||
| from typing import List | from typing import List | ||||
| from sklearn.manifold import TSNE | from sklearn.manifold import TSNE | ||||
| from core.embedding.cached_embedding import CacheEmbedding | from core.embedding.cached_embedding import CacheEmbedding | ||||
| from core.index.vector_index.vector_index import VectorIndex | |||||
| from core.model_providers.model_factory import ModelFactory | from core.model_providers.model_factory import ModelFactory | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from models.account import Account | from models.account import Account | ||||
| from models.dataset import Dataset, DocumentSegment, DatasetQuery | from models.dataset import Dataset, DocumentSegment, DatasetQuery | ||||
| from services.retrieval_service import RetrievalService | |||||
| default_retrieval_model = { | |||||
| 'search_method': 'semantic_search', | |||||
| 'reranking_enable': False, | |||||
| 'reranking_model': { | |||||
| 'reranking_provider_name': '', | |||||
| 'reranking_model_name': '' | |||||
| }, | |||||
| 'top_k': 2, | |||||
| 'score_threshold_enable': False | |||||
| } | |||||
| class HitTestingService: | class HitTestingService: | ||||
| @classmethod | @classmethod | ||||
| def retrieve(cls, dataset: Dataset, query: str, account: Account, limit: int = 10) -> dict: | |||||
| def retrieve(cls, dataset: Dataset, query: str, account: Account, retrieval_model: dict, limit: int = 10) -> dict: | |||||
| if dataset.available_document_count == 0 or dataset.available_segment_count == 0: | if dataset.available_document_count == 0 or dataset.available_segment_count == 0: | ||||
| return { | return { | ||||
| "query": { | "query": { | ||||
| "records": [] | "records": [] | ||||
| } | } | ||||
| start = time.perf_counter() | |||||
| # get the dataset's retrieval model; fall back to the default if it is not set | |||||
| if not retrieval_model: | |||||
| retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model | |||||
| # get embedding model | |||||
| embedding_model = ModelFactory.get_embedding_model( | embedding_model = ModelFactory.get_embedding_model( | ||||
| tenant_id=dataset.tenant_id, | tenant_id=dataset.tenant_id, | ||||
| model_provider_name=dataset.embedding_model_provider, | model_provider_name=dataset.embedding_model_provider, | ||||
| model_name=dataset.embedding_model | model_name=dataset.embedding_model | ||||
| ) | ) | ||||
| embeddings = CacheEmbedding(embedding_model) | embeddings = CacheEmbedding(embedding_model) | ||||
| vector_index = VectorIndex( | |||||
| dataset=dataset, | |||||
| config=current_app.config, | |||||
| embeddings=embeddings | |||||
| ) | |||||
| all_documents = [] | |||||
| threads = [] | |||||
| # semantic search (also used for hybrid search) | |||||
| if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search': | |||||
| embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={ | |||||
| 'flask_app': current_app._get_current_object(), | |||||
| 'dataset': dataset, | |||||
| 'query': query, | |||||
| 'top_k': retrieval_model['top_k'], | |||||
| 'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None, | |||||
| 'reranking_model': retrieval_model['reranking_model'] if retrieval_model['reranking_enable'] else None, | |||||
| 'all_documents': all_documents, | |||||
| 'search_method': retrieval_model['search_method'], | |||||
| 'embeddings': embeddings | |||||
| }) | |||||
| threads.append(embedding_thread) | |||||
| embedding_thread.start() | |||||
| # retrieval source with full-text search | |||||
| if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search': | |||||
| full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={ | |||||
| 'flask_app': current_app._get_current_object(), | |||||
| 'dataset': dataset, | |||||
| 'query': query, | |||||
| 'search_method': retrieval_model['search_method'], | |||||
| 'embeddings': embeddings, | |||||
| 'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None, | |||||
| 'top_k': retrieval_model['top_k'], | |||||
| 'reranking_model': retrieval_model['reranking_model'] if retrieval_model['reranking_enable'] else None, | |||||
| 'all_documents': all_documents | |||||
| }) | |||||
| threads.append(full_text_index_thread) | |||||
| full_text_index_thread.start() | |||||
| for thread in threads: | |||||
| thread.join() | |||||
| if retrieval_model['search_method'] == 'hybrid_search': | |||||
| hybrid_rerank = ModelFactory.get_reranking_model( | |||||
| tenant_id=dataset.tenant_id, | |||||
| model_provider_name=retrieval_model['reranking_model']['reranking_provider_name'], | |||||
| model_name=retrieval_model['reranking_model']['reranking_model_name'] | |||||
| ) | |||||
| all_documents = hybrid_rerank.rerank(query, all_documents, | |||||
| retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None, | |||||
| retrieval_model['top_k']) | |||||
| start = time.perf_counter() | |||||
| documents = vector_index.search( | |||||
| query, | |||||
| search_type='similarity_score_threshold', | |||||
| search_kwargs={ | |||||
| 'k': 10, | |||||
| 'filter': { | |||||
| 'group_id': [dataset.id] | |||||
| } | |||||
| } | |||||
| ) | |||||
| end = time.perf_counter() | end = time.perf_counter() | ||||
| logging.debug(f"Hit testing retrieve in {end - start:0.4f} seconds") | logging.debug(f"Hit testing retrieve in {end - start:0.4f} seconds") | ||||
| db.session.add(dataset_query) | db.session.add(dataset_query) | ||||
| db.session.commit() | db.session.commit() | ||||
| return cls.compact_retrieve_response(dataset, embeddings, query, documents) | |||||
| return cls.compact_retrieve_response(dataset, embeddings, query, all_documents) | |||||
| @classmethod | @classmethod | ||||
| def compact_retrieve_response(cls, dataset: Dataset, embeddings: Embeddings, query: str, documents: List[Document]): | def compact_retrieve_response(cls, dataset: Dataset, embeddings: Embeddings, query: str, documents: List[Document]): | ||||
| record = { | record = { | ||||
| "segment": segment, | "segment": segment, | ||||
| "score": document.metadata['score'], | |||||
| "score": document.metadata.get('score', None), | |||||
| "tsne_position": tsne_position_data[i] | "tsne_position": tsne_position_data[i] | ||||
| } | } | ||||
| tsne_position_data.append({'x': float(data_tsne[i][0]), 'y': float(data_tsne[i][1])}) | tsne_position_data.append({'x': float(data_tsne[i][0]), 'y': float(data_tsne[i][1])}) | ||||
| return tsne_position_data | return tsne_position_data | ||||
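The `tsne_position_data` built above comes from projecting the segment embeddings to two dimensions with scikit-learn's TSNE (imported at the top of this file). A standalone, hedged sketch of that projection; the helper name and the small-sample guard are assumptions, not part of this change:

    import numpy as np
    from sklearn.manifold import TSNE

    def tsne_positions(embedding_vectors: list) -> list:
        # perplexity must stay below the sample count; a single vector would still need special-casing
        data = np.array(embedding_vectors)
        perplexity = min(30, max(1, len(data) - 1))
        data_tsne = TSNE(n_components=2, perplexity=perplexity, random_state=0).fit_transform(data)
        return [{'x': float(point[0]), 'y': float(point[1])} for point in data_tsne]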
| @classmethod | |||||
| def hit_testing_args_check(cls, args): | |||||
| query = args['query'] | |||||
| if not query or len(query) > 250: | |||||
| raise ValueError('Query is required and cannot exceed 250 characters') | |||||
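For orientation, a hedged sketch of calling the updated `retrieve` signature from an API layer; the surrounding objects (`dataset`, `current_user`, the `args` dict) are illustrative and not part of this change:

    args = {'query': 'what is a vector index?', 'retrieval_model': None}
    HitTestingService.hit_testing_args_check(args)
    response = HitTestingService.retrieve(
        dataset=dataset,                          # a models.dataset.Dataset instance
        query=args['query'],
        account=current_user,                     # the acting Account
        retrieval_model=args['retrieval_model'],  # None falls back to the dataset config or the default above
        limit=10,
    )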
| from typing import Optional | |||||
| from flask import current_app, Flask | |||||
| from langchain.embeddings.base import Embeddings | |||||
| from core.index.vector_index.vector_index import VectorIndex | |||||
| from core.model_providers.model_factory import ModelFactory | |||||
| from models.dataset import Dataset | |||||
| default_retrieval_model = { | |||||
| 'search_method': 'semantic_search', | |||||
| 'reranking_enable': False, | |||||
| 'reranking_model': { | |||||
| 'reranking_provider_name': '', | |||||
| 'reranking_model_name': '' | |||||
| }, | |||||
| 'top_k': 2, | |||||
| 'score_threshold_enable': False | |||||
| } | |||||
| class RetrievalService: | |||||
| @classmethod | |||||
| def embedding_search(cls, flask_app: Flask, dataset: Dataset, query: str, | |||||
| top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict], | |||||
| all_documents: list, search_method: str, embeddings: Embeddings): | |||||
| with flask_app.app_context(): | |||||
| vector_index = VectorIndex( | |||||
| dataset=dataset, | |||||
| config=current_app.config, | |||||
| embeddings=embeddings | |||||
| ) | |||||
| documents = vector_index.search( | |||||
| query, | |||||
| search_type='similarity_score_threshold', | |||||
| search_kwargs={ | |||||
| 'k': top_k, | |||||
| 'score_threshold': score_threshold, | |||||
| 'filter': { | |||||
| 'group_id': [dataset.id] | |||||
| } | |||||
| } | |||||
| ) | |||||
| if documents: | |||||
| if reranking_model and search_method == 'semantic_search': | |||||
| rerank = ModelFactory.get_reranking_model( | |||||
| tenant_id=dataset.tenant_id, | |||||
| model_provider_name=reranking_model['reranking_provider_name'], | |||||
| model_name=reranking_model['reranking_model_name'] | |||||
| ) | |||||
| all_documents.extend(rerank.rerank(query, documents, score_threshold, len(documents))) | |||||
| else: | |||||
| all_documents.extend(documents) | |||||
| @classmethod | |||||
| def full_text_index_search(cls, flask_app: Flask, dataset: Dataset, query: str, | |||||
| top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict], | |||||
| all_documents: list, search_method: str, embeddings: Embeddings): | |||||
| with flask_app.app_context(): | |||||
| vector_index = VectorIndex( | |||||
| dataset=dataset, | |||||
| config=current_app.config, | |||||
| embeddings=embeddings | |||||
| ) | |||||
| documents = vector_index.search_by_full_text_index( | |||||
| query, | |||||
| search_type='similarity_score_threshold', | |||||
| top_k=top_k | |||||
| ) | |||||
| if documents: | |||||
| if reranking_model and search_method == 'full_text_search': | |||||
| rerank = ModelFactory.get_reranking_model( | |||||
| tenant_id=dataset.tenant_id, | |||||
| model_provider_name=reranking_model['reranking_provider_name'], | |||||
| model_name=reranking_model['reranking_model_name'] | |||||
| ) | |||||
| all_documents.extend(rerank.rerank(query, documents, score_threshold, len(documents))) | |||||
| else: | |||||
| all_documents.extend(documents) | |||||
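As a rough sketch, the two entry points above can be exercised sequentially (in the hit testing service they run in parallel threads that share `all_documents`); the `dataset` and `embeddings` objects are assumed to already exist:

    all_documents = []
    common = dict(
        flask_app=current_app._get_current_object(),
        dataset=dataset,
        query='example query',
        top_k=2,
        score_threshold=None,
        reranking_model=None,
        all_documents=all_documents,
        embeddings=embeddings,
    )
    RetrievalService.embedding_search(search_method='semantic_search', **common)
    RetrievalService.full_text_index_search(search_method='full_text_search', **common)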