| import logging | |||||
| from flask_restful import Resource | |||||
| from flask_login import current_user | |||||
| from flask_restful import Resource, marshal, reqparse | |||||
| from werkzeug.exceptions import Forbidden, InternalServerError, NotFound | |||||
| import services | |||||
| from controllers.console import api | from controllers.console import api | ||||
| from controllers.console.app.error import ( | |||||
| CompletionRequestError, | |||||
| ProviderModelCurrentlyNotSupportError, | |||||
| ProviderNotInitializeError, | |||||
| ProviderQuotaExceededError, | |||||
| ) | |||||
| from controllers.console.datasets.error import DatasetNotInitializedError | |||||
| from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase | |||||
| from controllers.console.setup import setup_required | from controllers.console.setup import setup_required | ||||
| from controllers.console.wraps import account_initialization_required | from controllers.console.wraps import account_initialization_required | ||||
| from core.errors.error import ( | |||||
| LLMBadRequestError, | |||||
| ModelCurrentlyNotSupportError, | |||||
| ProviderTokenNotInitError, | |||||
| QuotaExceededError, | |||||
| ) | |||||
| from core.model_runtime.errors.invoke import InvokeError | |||||
| from fields.hit_testing_fields import hit_testing_record_fields | |||||
| from libs.login import login_required | from libs.login import login_required | ||||
| from services.dataset_service import DatasetService | |||||
| from services.hit_testing_service import HitTestingService | |||||
| class HitTestingApi(Resource): | |||||
| class HitTestingApi(Resource, DatasetsHitTestingBase): | |||||
| @setup_required | @setup_required | ||||
| @login_required | @login_required | ||||
| @account_initialization_required | @account_initialization_required | ||||
| def post(self, dataset_id): | def post(self, dataset_id): | ||||
| dataset_id_str = str(dataset_id) | dataset_id_str = str(dataset_id) | ||||
| dataset = DatasetService.get_dataset(dataset_id_str) | |||||
| if dataset is None: | |||||
| raise NotFound("Dataset not found.") | |||||
| try: | |||||
| DatasetService.check_dataset_permission(dataset, current_user) | |||||
| except services.errors.account.NoPermissionError as e: | |||||
| raise Forbidden(str(e)) | |||||
| parser = reqparse.RequestParser() | |||||
| parser.add_argument("query", type=str, location="json") | |||||
| parser.add_argument("retrieval_model", type=dict, required=False, location="json") | |||||
| parser.add_argument("external_retrieval_model", type=dict, required=False, location="json") | |||||
| args = parser.parse_args() | |||||
| HitTestingService.hit_testing_args_check(args) | |||||
| try: | |||||
| response = HitTestingService.retrieve( | |||||
| dataset=dataset, | |||||
| query=args["query"], | |||||
| account=current_user, | |||||
| retrieval_model=args["retrieval_model"], | |||||
| external_retrieval_model=args["external_retrieval_model"], | |||||
| limit=10, | |||||
| ) | |||||
| dataset = self.get_and_validate_dataset(dataset_id_str) | |||||
| args = self.parse_args() | |||||
| self.hit_testing_args_check(args) | |||||
| return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)} | |||||
| except services.errors.index.IndexNotInitializedError: | |||||
| raise DatasetNotInitializedError() | |||||
| except ProviderTokenNotInitError as ex: | |||||
| raise ProviderNotInitializeError(ex.description) | |||||
| except QuotaExceededError: | |||||
| raise ProviderQuotaExceededError() | |||||
| except ModelCurrentlyNotSupportError: | |||||
| raise ProviderModelCurrentlyNotSupportError() | |||||
| except LLMBadRequestError: | |||||
| raise ProviderNotInitializeError( | |||||
| "No Embedding Model or Reranking Model available. Please configure a valid provider " | |||||
| "in the Settings -> Model Provider." | |||||
| ) | |||||
| except InvokeError as e: | |||||
| raise CompletionRequestError(e.description) | |||||
| except ValueError as e: | |||||
| raise ValueError(str(e)) | |||||
| except Exception as e: | |||||
| logging.exception("Hit testing failed.") | |||||
| raise InternalServerError(str(e)) | |||||
| return self.perform_hit_testing(dataset, args) | |||||
| api.add_resource(HitTestingApi, "/datasets/<uuid:dataset_id>/hit-testing") | api.add_resource(HitTestingApi, "/datasets/<uuid:dataset_id>/hit-testing") |
| import logging | |||||
| from flask_login import current_user | |||||
| from flask_restful import marshal, reqparse | |||||
| from werkzeug.exceptions import Forbidden, InternalServerError, NotFound | |||||
| import services.dataset_service | |||||
| from controllers.console.app.error import ( | |||||
| CompletionRequestError, | |||||
| ProviderModelCurrentlyNotSupportError, | |||||
| ProviderNotInitializeError, | |||||
| ProviderQuotaExceededError, | |||||
| ) | |||||
| from controllers.console.datasets.error import DatasetNotInitializedError | |||||
| from core.errors.error import ( | |||||
| LLMBadRequestError, | |||||
| ModelCurrentlyNotSupportError, | |||||
| ProviderTokenNotInitError, | |||||
| QuotaExceededError, | |||||
| ) | |||||
| from core.model_runtime.errors.invoke import InvokeError | |||||
| from fields.hit_testing_fields import hit_testing_record_fields | |||||
| from services.dataset_service import DatasetService | |||||
| from services.hit_testing_service import HitTestingService | |||||
| class DatasetsHitTestingBase: | |||||
| @staticmethod | |||||
| def get_and_validate_dataset(dataset_id: str): | |||||
| dataset = DatasetService.get_dataset(dataset_id) | |||||
| if dataset is None: | |||||
| raise NotFound("Dataset not found.") | |||||
| try: | |||||
| DatasetService.check_dataset_permission(dataset, current_user) | |||||
| except services.errors.account.NoPermissionError as e: | |||||
| raise Forbidden(str(e)) | |||||
| return dataset | |||||
| @staticmethod | |||||
| def hit_testing_args_check(args): | |||||
| HitTestingService.hit_testing_args_check(args) | |||||
| @staticmethod | |||||
| def parse_args(): | |||||
| parser = reqparse.RequestParser() | |||||
| parser.add_argument("query", type=str, location="json") | |||||
| parser.add_argument("retrieval_model", type=dict, required=False, location="json") | |||||
| parser.add_argument("external_retrieval_model", type=dict, required=False, location="json") | |||||
| return parser.parse_args() | |||||
| @staticmethod | |||||
| def perform_hit_testing(dataset, args): | |||||
| try: | |||||
| response = HitTestingService.retrieve( | |||||
| dataset=dataset, | |||||
| query=args["query"], | |||||
| account=current_user, | |||||
| retrieval_model=args["retrieval_model"], | |||||
| external_retrieval_model=args["external_retrieval_model"], | |||||
| limit=10, | |||||
| ) | |||||
| return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)} | |||||
| except services.errors.index.IndexNotInitializedError: | |||||
| raise DatasetNotInitializedError() | |||||
| except ProviderTokenNotInitError as ex: | |||||
| raise ProviderNotInitializeError(ex.description) | |||||
| except QuotaExceededError: | |||||
| raise ProviderQuotaExceededError() | |||||
| except ModelCurrentlyNotSupportError: | |||||
| raise ProviderModelCurrentlyNotSupportError() | |||||
| except LLMBadRequestError: | |||||
| raise ProviderNotInitializeError( | |||||
| "No Embedding Model or Reranking Model available. Please configure a valid provider " | |||||
| "in the Settings -> Model Provider." | |||||
| ) | |||||
| except InvokeError as e: | |||||
| raise CompletionRequestError(e.description) | |||||
| except ValueError as e: | |||||
| raise ValueError(str(e)) | |||||
| except Exception as e: | |||||
| logging.exception("Hit testing failed.") | |||||
| raise InternalServerError(str(e)) |
| bp = Blueprint("service_api", __name__, url_prefix="/v1") | bp = Blueprint("service_api", __name__, url_prefix="/v1") | ||||
| api = ExternalApi(bp) | api = ExternalApi(bp) | ||||
| from . import index | from . import index | ||||
| from .app import app, audio, completion, conversation, file, message, workflow | from .app import app, audio, completion, conversation, file, message, workflow | ||||
| from .dataset import dataset, document, segment | |||||
| from .dataset import dataset, document, hit_testing, segment |
| from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase | |||||
| from controllers.service_api import api | |||||
| from controllers.service_api.wraps import DatasetApiResource | |||||
| class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase): | |||||
| def post(self, tenant_id, dataset_id): | |||||
| dataset_id_str = str(dataset_id) | |||||
| dataset = self.get_and_validate_dataset(dataset_id_str) | |||||
| args = self.parse_args() | |||||
| self.hit_testing_args_check(args) | |||||
| return self.perform_hit_testing(dataset, args) | |||||
| api.add_resource(HitTestingApi, "/datasets/<uuid:dataset_id>/hit-testing") |
| --- | --- | ||||
| <Heading | |||||
| url='/datasets/{dataset_id}/hit_testing' | |||||
| method='POST' | |||||
| title='Dataset hit testing' | |||||
| name='#dataset_hit_testing' | |||||
| /> | |||||
| <Row> | |||||
| <Col> | |||||
| ### Path | |||||
| <Properties> | |||||
| <Property name='dataset_id' type='string' key='dataset_id'> | |||||
| Dataset ID | |||||
| </Property> | |||||
| </Properties> | |||||
| ### Request Body | |||||
| <Properties> | |||||
| <Property name='query' type='string' key='query'> | |||||
| retrieval keywordc | |||||
| </Property> | |||||
| <Property name='retrieval_model' type='object' key='retrieval_model'> | |||||
| retrieval keyword(Optional, if not filled, it will be recalled according to the default method) | |||||
| - <code>search_method</code> (text) Search method: One of the following four keywords is required | |||||
| - <code>keyword_search</code> Keyword search | |||||
| - <code>semantic_search</code> Semantic search | |||||
| - <code>full_text_search</code> Full-text search | |||||
| - <code>hybrid_search</code> Hybrid search | |||||
| - <code>reranking_enable</code> (bool) Whether to enable reranking, optional, required if the search mode is semantic_search or hybrid_search | |||||
| - <code>reranking_mode</code> (object) Rerank model configuration, optional, required if reranking is enabled | |||||
| - <code>reranking_provider_name</code> (string) Rerank model provider | |||||
| - <code>reranking_model_name</code> (string) Rerank model name | |||||
| - <code>weights</code> (double) Semantic search weight setting in hybrid search mode | |||||
| - <code>top_k</code> (integer) Number of results to return, optional | |||||
| - <code>score_threshold_enabled</code> (bool) Whether to enable score threshold | |||||
| - <code>score_threshold</code> (double) Score threshold | |||||
| </Property> | |||||
| <Property name='external_retrieval_model' type='object' key='external_retrieval_model'> | |||||
| Unused field | |||||
| </Property> | |||||
| </Properties> | |||||
| </Col> | |||||
| <Col sticky> | |||||
| <CodeGroup | |||||
| title="Request" | |||||
| tag="POST" | |||||
| label="/datasets/{dataset_id}/hit_testing" | |||||
| targetCode={`curl --location --request GET '${props.apiBaseUrl}/datasets/{dataset_id}/hit_testing' \\\n--header 'Authorization: Bearer {api_key}'\\\n--header 'Content-Type: application/json'\\\n--data-raw '{ | |||||
| "query": "test", | |||||
| "retrieval_model": { | |||||
| "search_method": "keyword_search", | |||||
| "reranking_enable": false, | |||||
| "reranking_mode": null, | |||||
| "reranking_model": { | |||||
| "reranking_provider_name": "", | |||||
| "reranking_model_name": "" | |||||
| }, | |||||
| "weights": null, | |||||
| "top_k": 1, | |||||
| "score_threshold_enabled": false, | |||||
| "score_threshold": null | |||||
| } | |||||
| }'`} | |||||
| > | |||||
| ```bash {{ title: 'cURL' }} | |||||
| curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/hit_testing' \ | |||||
| --header 'Authorization: Bearer {api_key}' \ | |||||
| --header 'Content-Type: application/json' \ | |||||
| --data-raw '{ | |||||
| "query": "test", | |||||
| "retrieval_model": { | |||||
| "search_method": "keyword_search", | |||||
| "reranking_enable": false, | |||||
| "reranking_mode": null, | |||||
| "reranking_model": { | |||||
| "reranking_provider_name": "", | |||||
| "reranking_model_name": "" | |||||
| }, | |||||
| "weights": null, | |||||
| "top_k": 2, | |||||
| "score_threshold_enabled": false, | |||||
| "score_threshold": null | |||||
| } | |||||
| }' | |||||
| ``` | |||||
| </CodeGroup> | |||||
| <CodeGroup title="Response"> | |||||
| ```json {{ title: 'Response' }} | |||||
| { | |||||
| "query": { | |||||
| "content": "test" | |||||
| }, | |||||
| "records": [ | |||||
| { | |||||
| "segment": { | |||||
| "id": "7fa6f24f-8679-48b3-bc9d-bdf28d73f218", | |||||
| "position": 1, | |||||
| "document_id": "a8c6c36f-9f5d-4d7a-8472-f5d7b75d71d2", | |||||
| "content": "Operation guide", | |||||
| "answer": null, | |||||
| "word_count": 847, | |||||
| "tokens": 280, | |||||
| "keywords": [ | |||||
| "install", | |||||
| "java", | |||||
| "base", | |||||
| "scripts", | |||||
| "jdk", | |||||
| "manual", | |||||
| "internal", | |||||
| "opens", | |||||
| "add", | |||||
| "vmoptions" | |||||
| ], | |||||
| "index_node_id": "39dd8443-d960-45a8-bb46-7275ad7fbc8e", | |||||
| "index_node_hash": "0189157697b3c6a418ccf8264a09699f25858975578f3467c76d6bfc94df1d73", | |||||
| "hit_count": 0, | |||||
| "enabled": true, | |||||
| "disabled_at": null, | |||||
| "disabled_by": null, | |||||
| "status": "completed", | |||||
| "created_by": "dbcb1ab5-90c8-41a7-8b78-73b235eb6f6f", | |||||
| "created_at": 1728734540, | |||||
| "indexing_at": 1728734552, | |||||
| "completed_at": 1728734584, | |||||
| "error": null, | |||||
| "stopped_at": null, | |||||
| "document": { | |||||
| "id": "a8c6c36f-9f5d-4d7a-8472-f5d7b75d71d2", | |||||
| "data_source_type": "upload_file", | |||||
| "name": "readme.txt", | |||||
| "doc_type": null | |||||
| } | |||||
| }, | |||||
| "score": 3.730463140527718e-05, | |||||
| "tsne_position": null | |||||
| } | |||||
| ] | |||||
| } | |||||
| ``` | |||||
| </CodeGroup> | |||||
| </Col> | |||||
| </Row> | |||||
| --- | |||||
| <Row> | <Row> | ||||
| <Col> | <Col> | ||||
| ### Error message | ### Error message |
| </Col> | </Col> | ||||
| </Row> | </Row> | ||||
| --- | |||||
| <Heading | |||||
| url='/datasets/{dataset_id}/hit_testing' | |||||
| method='POST' | |||||
| title='知识库召回测试' | |||||
| name='#dataset_hit_testing' | |||||
| /> | |||||
| <Row> | |||||
| <Col> | |||||
| ### Path | |||||
| <Properties> | |||||
| <Property name='dataset_id' type='string' key='dataset_id'> | |||||
| 知识库 ID | |||||
| </Property> | |||||
| </Properties> | |||||
| ### Request Body | |||||
| <Properties> | |||||
| <Property name='query' type='string' key='query'> | |||||
| 召回关键词 | |||||
| </Property> | |||||
| <Property name='retrieval_model' type='object' key='retrieval_model'> | |||||
| 召回参数(选填,如不填,按照默认方式召回) | |||||
| - <code>search_method</code> (text) 检索方法:以下三个关键字之一,必填 | |||||
| - <code>keyword_search</code> 关键字检索 | |||||
| - <code>semantic_search</code> 语义检索 | |||||
| - <code>full_text_search</code> 全文检索 | |||||
| - <code>hybrid_search</code> 混合检索 | |||||
| - <code>reranking_enable</code> (bool) 是否启用 Reranking,非必填,如果检索模式为semantic_search模式或者hybrid_search则传值 | |||||
| - <code>reranking_mode</code> (object) Rerank模型配置,非必填,如果启用了 reranking 则传值 | |||||
| - <code>reranking_provider_name</code> (string) Rerank 模型提供商 | |||||
| - <code>reranking_model_name</code> (string) Rerank 模型名称 | |||||
| - <code>weights</code> (double) 混合检索模式下语意检索的权重设置 | |||||
| - <code>top_k</code> (integer) 返回结果数量,非必填 | |||||
| - <code>score_threshold_enabled</code> (bool) 是否开启Score阈值 | |||||
| - <code>score_threshold</code> (double) Score阈值 | |||||
| </Property> | |||||
| <Property name='external_retrieval_model' type='object' key='external_retrieval_model'> | |||||
| 未启用字段 | |||||
| </Property> | |||||
| </Properties> | |||||
| </Col> | |||||
| <Col sticky> | |||||
| <CodeGroup | |||||
| title="Request" | |||||
| tag="POST" | |||||
| label="/datasets/{dataset_id}/hit_testing" | |||||
| targetCode={`curl --location --request GET '${props.apiBaseUrl}/datasets/{dataset_id}/hit_testing' \\\n--header 'Authorization: Bearer {api_key}'\\\n--header 'Content-Type: application/json'\\\n--data-raw '{ | |||||
| "query": "test", | |||||
| "retrieval_model": { | |||||
| "search_method": "keyword_search", | |||||
| "reranking_enable": false, | |||||
| "reranking_mode": null, | |||||
| "reranking_model": { | |||||
| "reranking_provider_name": "", | |||||
| "reranking_model_name": "" | |||||
| }, | |||||
| "weights": null, | |||||
| "top_k": 1, | |||||
| "score_threshold_enabled": false, | |||||
| "score_threshold": null | |||||
| } | |||||
| }'`} | |||||
| > | |||||
| ```bash {{ title: 'cURL' }} | |||||
| curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/hit_testing' \ | |||||
| --header 'Authorization: Bearer {api_key}' \ | |||||
| --header 'Content-Type: application/json' \ | |||||
| --data-raw '{ | |||||
| "query": "test", | |||||
| "retrieval_model": { | |||||
| "search_method": "keyword_search", | |||||
| "reranking_enable": false, | |||||
| "reranking_mode": null, | |||||
| "reranking_model": { | |||||
| "reranking_provider_name": "", | |||||
| "reranking_model_name": "" | |||||
| }, | |||||
| "weights": null, | |||||
| "top_k": 2, | |||||
| "score_threshold_enabled": false, | |||||
| "score_threshold": null | |||||
| } | |||||
| }' | |||||
| ``` | |||||
| </CodeGroup> | |||||
| <CodeGroup title="Response"> | |||||
| ```json {{ title: 'Response' }} | |||||
| { | |||||
| "query": { | |||||
| "content": "test" | |||||
| }, | |||||
| "records": [ | |||||
| { | |||||
| "segment": { | |||||
| "id": "7fa6f24f-8679-48b3-bc9d-bdf28d73f218", | |||||
| "position": 1, | |||||
| "document_id": "a8c6c36f-9f5d-4d7a-8472-f5d7b75d71d2", | |||||
| "content": "Operation guide", | |||||
| "answer": null, | |||||
| "word_count": 847, | |||||
| "tokens": 280, | |||||
| "keywords": [ | |||||
| "install", | |||||
| "java", | |||||
| "base", | |||||
| "scripts", | |||||
| "jdk", | |||||
| "manual", | |||||
| "internal", | |||||
| "opens", | |||||
| "add", | |||||
| "vmoptions" | |||||
| ], | |||||
| "index_node_id": "39dd8443-d960-45a8-bb46-7275ad7fbc8e", | |||||
| "index_node_hash": "0189157697b3c6a418ccf8264a09699f25858975578f3467c76d6bfc94df1d73", | |||||
| "hit_count": 0, | |||||
| "enabled": true, | |||||
| "disabled_at": null, | |||||
| "disabled_by": null, | |||||
| "status": "completed", | |||||
| "created_by": "dbcb1ab5-90c8-41a7-8b78-73b235eb6f6f", | |||||
| "created_at": 1728734540, | |||||
| "indexing_at": 1728734552, | |||||
| "completed_at": 1728734584, | |||||
| "error": null, | |||||
| "stopped_at": null, | |||||
| "document": { | |||||
| "id": "a8c6c36f-9f5d-4d7a-8472-f5d7b75d71d2", | |||||
| "data_source_type": "upload_file", | |||||
| "name": "readme.txt", | |||||
| "doc_type": null | |||||
| } | |||||
| }, | |||||
| "score": 3.730463140527718e-05, | |||||
| "tsne_position": null | |||||
| } | |||||
| ] | |||||
| } | |||||
| ``` | |||||
| </CodeGroup> | |||||
| </Col> | |||||
| </Row> | |||||
| --- | --- | ||||
| <Row> | <Row> |