| @@ -545,15 +545,15 @@ class DatasetRetrievalSettingApi(Resource): | |||
| case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT | VectorType.ORACLE: | |||
| return { | |||
| 'retrieval_method': [ | |||
| RetrievalMethod.SEMANTIC_SEARCH | |||
| RetrievalMethod.SEMANTIC_SEARCH.value | |||
| ] | |||
| } | |||
| case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE: | |||
| return { | |||
| 'retrieval_method': [ | |||
| RetrievalMethod.SEMANTIC_SEARCH, | |||
| RetrievalMethod.FULL_TEXT_SEARCH, | |||
| RetrievalMethod.HYBRID_SEARCH, | |||
| RetrievalMethod.SEMANTIC_SEARCH.value, | |||
| RetrievalMethod.FULL_TEXT_SEARCH.value, | |||
| RetrievalMethod.HYBRID_SEARCH.value, | |||
| ] | |||
| } | |||
| case _: | |||
| @@ -569,15 +569,15 @@ class DatasetRetrievalSettingMockApi(Resource): | |||
| case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT | VectorType.ORACLE: | |||
| return { | |||
| 'retrieval_method': [ | |||
| RetrievalMethod.SEMANTIC_SEARCH | |||
| RetrievalMethod.SEMANTIC_SEARCH.value | |||
| ] | |||
| } | |||
| case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH| VectorType.ANALYTICDB | VectorType.MYSCALE: | |||
| return { | |||
| 'retrieval_method': [ | |||
| RetrievalMethod.SEMANTIC_SEARCH, | |||
| RetrievalMethod.FULL_TEXT_SEARCH, | |||
| RetrievalMethod.HYBRID_SEARCH, | |||
| RetrievalMethod.SEMANTIC_SEARCH.value, | |||
| RetrievalMethod.FULL_TEXT_SEARCH.value, | |||
| RetrievalMethod.HYBRID_SEARCH.value, | |||
| ] | |||
| } | |||
| case _: | |||
| @@ -11,7 +11,7 @@ from extensions.ext_database import db | |||
| from models.dataset import Dataset | |||
| default_retrieval_model = { | |||
| 'search_method': RetrievalMethod.SEMANTIC_SEARCH, | |||
| 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, | |||
| 'reranking_enable': False, | |||
| 'reranking_model': { | |||
| 'reranking_provider_name': '', | |||
| @@ -86,7 +86,7 @@ class RetrievalService: | |||
| exception_message = ';\n'.join(exceptions) | |||
| raise Exception(exception_message) | |||
| if retrival_method == RetrievalMethod.HYBRID_SEARCH: | |||
| if retrival_method == RetrievalMethod.HYBRID_SEARCH.value: | |||
| data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False) | |||
| all_documents = data_post_processor.invoke( | |||
| query=query, | |||
| @@ -142,7 +142,7 @@ class RetrievalService: | |||
| ) | |||
| if documents: | |||
| if reranking_model and retrival_method == RetrievalMethod.SEMANTIC_SEARCH: | |||
| if reranking_model and retrival_method == RetrievalMethod.SEMANTIC_SEARCH.value: | |||
| data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False) | |||
| all_documents.extend(data_post_processor.invoke( | |||
| query=query, | |||
| @@ -174,7 +174,7 @@ class RetrievalService: | |||
| top_k=top_k | |||
| ) | |||
| if documents: | |||
| if reranking_model and retrival_method == RetrievalMethod.FULL_TEXT_SEARCH: | |||
| if reranking_model and retrival_method == RetrievalMethod.FULL_TEXT_SEARCH.value: | |||
| data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False) | |||
| all_documents.extend(data_post_processor.invoke( | |||
| query=query, | |||
| @@ -28,7 +28,7 @@ from models.dataset import Dataset, DatasetQuery, DocumentSegment | |||
| from models.dataset import Document as DatasetDocument | |||
| default_retrieval_model = { | |||
| 'search_method': RetrievalMethod.SEMANTIC_SEARCH, | |||
| 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, | |||
| 'reranking_enable': False, | |||
| 'reranking_model': { | |||
| 'reranking_provider_name': '', | |||
| @@ -464,7 +464,7 @@ class DatasetRetrieval: | |||
| if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE: | |||
| # get retrieval model config | |||
| default_retrieval_model = { | |||
| 'search_method': RetrievalMethod.SEMANTIC_SEARCH, | |||
| 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, | |||
| 'reranking_enable': False, | |||
| 'reranking_model': { | |||
| 'reranking_provider_name': '', | |||
| @@ -1,15 +1,15 @@ | |||
| from enum import Enum | |||
| class RetrievalMethod(str, Enum): | |||
| class RetrievalMethod(Enum): | |||
| SEMANTIC_SEARCH = 'semantic_search' | |||
| FULL_TEXT_SEARCH = 'full_text_search' | |||
| HYBRID_SEARCH = 'hybrid_search' | |||
| @staticmethod | |||
| def is_support_semantic_search(retrieval_method: str) -> bool: | |||
| return retrieval_method in {RetrievalMethod.SEMANTIC_SEARCH, RetrievalMethod.HYBRID_SEARCH} | |||
| return retrieval_method in {RetrievalMethod.SEMANTIC_SEARCH.value, RetrievalMethod.HYBRID_SEARCH.value} | |||
| @staticmethod | |||
| def is_support_fulltext_search(retrieval_method: str) -> bool: | |||
| return retrieval_method in {RetrievalMethod.FULL_TEXT_SEARCH, RetrievalMethod.HYBRID_SEARCH} | |||
| return retrieval_method in {RetrievalMethod.FULL_TEXT_SEARCH.value, RetrievalMethod.HYBRID_SEARCH.value} | |||
| @@ -14,7 +14,7 @@ from extensions.ext_database import db | |||
| from models.dataset import Dataset, Document, DocumentSegment | |||
| default_retrieval_model = { | |||
| 'search_method': RetrievalMethod.SEMANTIC_SEARCH, | |||
| 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, | |||
| 'reranking_enable': False, | |||
| 'reranking_model': { | |||
| 'reranking_provider_name': '', | |||
| @@ -8,7 +8,7 @@ from extensions.ext_database import db | |||
| from models.dataset import Dataset, Document, DocumentSegment | |||
| default_retrieval_model = { | |||
| 'search_method': RetrievalMethod.SEMANTIC_SEARCH, | |||
| 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, | |||
| 'reranking_enable': False, | |||
| 'reranking_model': { | |||
| 'reranking_provider_name': '', | |||
| @@ -22,7 +22,7 @@ from models.dataset import Dataset, Document, DocumentSegment | |||
| from models.workflow import WorkflowNodeExecutionStatus | |||
| default_retrieval_model = { | |||
| 'search_method': RetrievalMethod.SEMANTIC_SEARCH, | |||
| 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, | |||
| 'reranking_enable': False, | |||
| 'reranking_model': { | |||
| 'reranking_provider_name': '', | |||
| @@ -117,7 +117,7 @@ class Dataset(db.Model): | |||
| @property | |||
| def retrieval_model_dict(self): | |||
| default_retrieval_model = { | |||
| 'search_method': RetrievalMethod.SEMANTIC_SEARCH, | |||
| 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, | |||
| 'reranking_enable': False, | |||
| 'reranking_model': { | |||
| 'reranking_provider_name': '', | |||
| @@ -688,7 +688,7 @@ class DocumentService: | |||
| dataset.collection_binding_id = dataset_collection_binding.id | |||
| if not dataset.retrieval_model: | |||
| default_retrieval_model = { | |||
| 'search_method': RetrievalMethod.SEMANTIC_SEARCH, | |||
| 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, | |||
| 'reranking_enable': False, | |||
| 'reranking_model': { | |||
| 'reranking_provider_name': '', | |||
| @@ -1059,7 +1059,7 @@ class DocumentService: | |||
| retrieval_model = document_data['retrieval_model'] | |||
| else: | |||
| default_retrieval_model = { | |||
| 'search_method': RetrievalMethod.SEMANTIC_SEARCH, | |||
| 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, | |||
| 'reranking_enable': False, | |||
| 'reranking_model': { | |||
| 'reranking_provider_name': '', | |||
| @@ -9,7 +9,7 @@ from models.account import Account | |||
| from models.dataset import Dataset, DatasetQuery, DocumentSegment | |||
| default_retrieval_model = { | |||
| 'search_method': RetrievalMethod.SEMANTIC_SEARCH, | |||
| 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, | |||
| 'reranking_enable': False, | |||
| 'reranking_model': { | |||
| 'reranking_provider_name': '', | |||