Bladeren bron

Fix/multi thread parameter (#1604)

tags/0.3.31-fix3
Jyong 1 jaar geleden
bovenliggende
commit
a5b80c9d1f
No account linked to committer's email address

+ 2
- 2
api/core/tool/dataset_multi_retriever_tool.py Bestand weergeven

'search_method'] == 'hybrid_search': 'search_method'] == 'hybrid_search':
embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={ embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
'flask_app': current_app._get_current_object(), 'flask_app': current_app._get_current_object(),
'dataset': dataset,
'dataset_id': str(dataset.id),
'query': query, 'query': query,
'top_k': self.top_k, 'top_k': self.top_k,
'score_threshold': self.score_threshold, 'score_threshold': self.score_threshold,
full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search,
kwargs={ kwargs={
'flask_app': current_app._get_current_object(), 'flask_app': current_app._get_current_object(),
'dataset': dataset,
'dataset_id': str(dataset.id),
'query': query, 'query': query,
'search_method': 'hybrid_search', 'search_method': 'hybrid_search',
'embeddings': embeddings, 'embeddings': embeddings,

+ 2
- 2
api/core/tool/dataset_retriever_tool.py Bestand weergeven

if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search': if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search':
embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={ embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
'flask_app': current_app._get_current_object(), 'flask_app': current_app._get_current_object(),
'dataset': dataset,
'dataset_id': str(dataset.id),
'query': query, 'query': query,
'top_k': self.top_k, 'top_k': self.top_k,
'score_threshold': retrieval_model['score_threshold'] if retrieval_model[ 'score_threshold': retrieval_model['score_threshold'] if retrieval_model[
if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search': if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search':
full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={ full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
'flask_app': current_app._get_current_object(), 'flask_app': current_app._get_current_object(),
'dataset': dataset,
'dataset_id': str(dataset.id),
'query': query, 'query': query,
'search_method': retrieval_model['search_method'], 'search_method': retrieval_model['search_method'],
'embeddings': embeddings, 'embeddings': embeddings,

+ 2
- 2
api/services/hit_testing_service.py Bestand weergeven

if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search': if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search':
embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={ embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
'flask_app': current_app._get_current_object(), 'flask_app': current_app._get_current_object(),
'dataset': dataset,
'dataset_id': str(dataset.id),
'query': query, 'query': query,
'top_k': retrieval_model['top_k'], 'top_k': retrieval_model['top_k'],
'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None, 'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None,
if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search': if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search':
full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={ full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
'flask_app': current_app._get_current_object(), 'flask_app': current_app._get_current_object(),
'dataset': dataset,
'dataset_id': str(dataset.id),
'query': query, 'query': query,
'search_method': retrieval_model['search_method'], 'search_method': retrieval_model['search_method'],
'embeddings': embeddings, 'embeddings': embeddings,

+ 9
- 2
api/services/retrieval_service.py Bestand weergeven

from langchain.embeddings.base import Embeddings from langchain.embeddings.base import Embeddings
from core.index.vector_index.vector_index import VectorIndex from core.index.vector_index.vector_index import VectorIndex
from core.model_providers.model_factory import ModelFactory from core.model_providers.model_factory import ModelFactory
from extensions.ext_database import db
from models.dataset import Dataset from models.dataset import Dataset


default_retrieval_model = { default_retrieval_model = {
class RetrievalService: class RetrievalService:


@classmethod @classmethod
def embedding_search(cls, flask_app: Flask, dataset: Dataset, query: str,
def embedding_search(cls, flask_app: Flask, dataset_id: str, query: str,
top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict], top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
all_documents: list, search_method: str, embeddings: Embeddings): all_documents: list, search_method: str, embeddings: Embeddings):
with flask_app.app_context(): with flask_app.app_context():
dataset = db.session.query(Dataset).filter(
Dataset.id == dataset_id
).first()


vector_index = VectorIndex( vector_index = VectorIndex(
dataset=dataset, dataset=dataset,
all_documents.extend(documents) all_documents.extend(documents)


@classmethod @classmethod
def full_text_index_search(cls, flask_app: Flask, dataset: Dataset, query: str,
def full_text_index_search(cls, flask_app: Flask, dataset_id: str, query: str,
top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict], top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
all_documents: list, search_method: str, embeddings: Embeddings): all_documents: list, search_method: str, embeddings: Embeddings):
with flask_app.app_context(): with flask_app.app_context():
dataset = db.session.query(Dataset).filter(
Dataset.id == dataset_id
).first()


vector_index = VectorIndex( vector_index = VectorIndex(
dataset=dataset, dataset=dataset,

Laden…
Annuleren
Opslaan