| @@ -192,7 +192,7 @@ class DatasetMultiRetrieverTool(BaseTool): | |||
| 'search_method'] == 'hybrid_search': | |||
| embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={ | |||
| 'flask_app': current_app._get_current_object(), | |||
| 'dataset': dataset, | |||
| 'dataset_id': str(dataset.id), | |||
| 'query': query, | |||
| 'top_k': self.top_k, | |||
| 'score_threshold': self.score_threshold, | |||
| @@ -210,7 +210,7 @@ class DatasetMultiRetrieverTool(BaseTool): | |||
| full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, | |||
| kwargs={ | |||
| 'flask_app': current_app._get_current_object(), | |||
| 'dataset': dataset, | |||
| 'dataset_id': str(dataset.id), | |||
| 'query': query, | |||
| 'search_method': 'hybrid_search', | |||
| 'embeddings': embeddings, | |||
| @@ -106,7 +106,7 @@ class DatasetRetrieverTool(BaseTool): | |||
| if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search': | |||
| embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={ | |||
| 'flask_app': current_app._get_current_object(), | |||
| 'dataset': dataset, | |||
| 'dataset_id': str(dataset.id), | |||
| 'query': query, | |||
| 'top_k': self.top_k, | |||
| 'score_threshold': retrieval_model['score_threshold'] if retrieval_model[ | |||
| @@ -124,7 +124,7 @@ class DatasetRetrieverTool(BaseTool): | |||
| if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search': | |||
| full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={ | |||
| 'flask_app': current_app._get_current_object(), | |||
| 'dataset': dataset, | |||
| 'dataset_id': str(dataset.id), | |||
| 'query': query, | |||
| 'search_method': retrieval_model['search_method'], | |||
| 'embeddings': embeddings, | |||
| @@ -61,7 +61,7 @@ class HitTestingService: | |||
| if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search': | |||
| embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={ | |||
| 'flask_app': current_app._get_current_object(), | |||
| 'dataset': dataset, | |||
| 'dataset_id': str(dataset.id), | |||
| 'query': query, | |||
| 'top_k': retrieval_model['top_k'], | |||
| 'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None, | |||
| @@ -77,7 +77,7 @@ class HitTestingService: | |||
| if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search': | |||
| full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={ | |||
| 'flask_app': current_app._get_current_object(), | |||
| 'dataset': dataset, | |||
| 'dataset_id': str(dataset.id), | |||
| 'query': query, | |||
| 'search_method': retrieval_model['search_method'], | |||
| 'embeddings': embeddings, | |||
| @@ -4,6 +4,7 @@ from flask import current_app, Flask | |||
| from langchain.embeddings.base import Embeddings | |||
| from core.index.vector_index.vector_index import VectorIndex | |||
| from core.model_providers.model_factory import ModelFactory | |||
| from extensions.ext_database import db | |||
| from models.dataset import Dataset | |||
| default_retrieval_model = { | |||
| @@ -21,10 +22,13 @@ default_retrieval_model = { | |||
| class RetrievalService: | |||
| @classmethod | |||
| def embedding_search(cls, flask_app: Flask, dataset: Dataset, query: str, | |||
| def embedding_search(cls, flask_app: Flask, dataset_id: str, query: str, | |||
| top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict], | |||
| all_documents: list, search_method: str, embeddings: Embeddings): | |||
| with flask_app.app_context(): | |||
| dataset = db.session.query(Dataset).filter( | |||
| Dataset.id == dataset_id | |||
| ).first() | |||
| vector_index = VectorIndex( | |||
| dataset=dataset, | |||
| @@ -56,10 +60,13 @@ class RetrievalService: | |||
| all_documents.extend(documents) | |||
| @classmethod | |||
| def full_text_index_search(cls, flask_app: Flask, dataset: Dataset, query: str, | |||
| def full_text_index_search(cls, flask_app: Flask, dataset_id: str, query: str, | |||
| top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict], | |||
| all_documents: list, search_method: str, embeddings: Embeddings): | |||
| with flask_app.app_context(): | |||
| dataset = db.session.query(Dataset).filter( | |||
| Dataset.id == dataset_id | |||
| ).first() | |||
| vector_index = VectorIndex( | |||
| dataset=dataset, | |||