| @@ -1,88 +1,24 @@ | |||
| import logging | |||
| from flask_restful import Resource | |||
| from flask_login import current_user | |||
| from flask_restful import Resource, marshal, reqparse | |||
| from werkzeug.exceptions import Forbidden, InternalServerError, NotFound | |||
| import services | |||
| from controllers.console import api | |||
| from controllers.console.app.error import ( | |||
| CompletionRequestError, | |||
| ProviderModelCurrentlyNotSupportError, | |||
| ProviderNotInitializeError, | |||
| ProviderQuotaExceededError, | |||
| ) | |||
| from controllers.console.datasets.error import DatasetNotInitializedError | |||
| from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase | |||
| from controllers.console.setup import setup_required | |||
| from controllers.console.wraps import account_initialization_required | |||
| from core.errors.error import ( | |||
| LLMBadRequestError, | |||
| ModelCurrentlyNotSupportError, | |||
| ProviderTokenNotInitError, | |||
| QuotaExceededError, | |||
| ) | |||
| from core.model_runtime.errors.invoke import InvokeError | |||
| from fields.hit_testing_fields import hit_testing_record_fields | |||
| from libs.login import login_required | |||
| from services.dataset_service import DatasetService | |||
| from services.hit_testing_service import HitTestingService | |||
class HitTestingApi(Resource, DatasetsHitTestingBase):
    """Console endpoint that runs a hit test (retrieval test) on a dataset.

    The request is handled entirely through the static helpers inherited from
    ``DatasetsHitTestingBase``; this class only adds the console decorators
    and fixes the order in which the helpers run.
    """

    @setup_required
    @login_required
    @account_initialization_required
    def post(self, dataset_id):
        """Run a hit test against the dataset identified by ``dataset_id``.

        Args:
            dataset_id: Dataset identifier taken from the URL; it is converted
                with ``str()`` before being handed to the lookup helper.

        Returns:
            Whatever ``perform_hit_testing`` returns for the validated
            dataset and the parsed request arguments.
        """
        dataset_id_str = str(dataset_id)

        # Lookup and permission check come first so that an unknown or
        # forbidden dataset is rejected before the request body is parsed.
        dataset = self.get_and_validate_dataset(dataset_id_str)
        args = self.parse_args()
        self.hit_testing_args_check(args)

        return self.perform_hit_testing(dataset, args)
| api.add_resource(HitTestingApi, "/datasets/<uuid:dataset_id>/hit-testing") | |||
| @@ -0,0 +1,85 @@ | |||
| import logging | |||
| from flask_login import current_user | |||
| from flask_restful import marshal, reqparse | |||
| from werkzeug.exceptions import Forbidden, InternalServerError, NotFound | |||
| import services.dataset_service | |||
| from controllers.console.app.error import ( | |||
| CompletionRequestError, | |||
| ProviderModelCurrentlyNotSupportError, | |||
| ProviderNotInitializeError, | |||
| ProviderQuotaExceededError, | |||
| ) | |||
| from controllers.console.datasets.error import DatasetNotInitializedError | |||
| from core.errors.error import ( | |||
| LLMBadRequestError, | |||
| ModelCurrentlyNotSupportError, | |||
| ProviderTokenNotInitError, | |||
| QuotaExceededError, | |||
| ) | |||
| from core.model_runtime.errors.invoke import InvokeError | |||
| from fields.hit_testing_fields import hit_testing_record_fields | |||
| from services.dataset_service import DatasetService | |||
| from services.hit_testing_service import HitTestingService | |||
class DatasetsHitTestingBase:
    """Shared request-handling steps for dataset hit-testing endpoints.

    Every helper is a static method, so a resource class can mix this class in
    and call the steps in sequence: look up the dataset, parse the JSON body,
    validate the arguments, then run the retrieval.
    """

    @staticmethod
    def get_and_validate_dataset(dataset_id: str):
        """Return the dataset for ``dataset_id`` if the current user may use it.

        Args:
            dataset_id: Identifier passed to ``DatasetService.get_dataset``.

        Returns:
            The dataset object returned by ``DatasetService.get_dataset``.

        Raises:
            NotFound: If no dataset is returned for ``dataset_id``.
            Forbidden: If the permission check raises ``NoPermissionError``;
                the original error message is reused.
        """
        dataset = DatasetService.get_dataset(dataset_id)
        if dataset is None:
            raise NotFound("Dataset not found.")

        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e)) from e

        return dataset

    @staticmethod
    def hit_testing_args_check(args) -> None:
        """Validate parsed request arguments via ``HitTestingService``."""
        HitTestingService.hit_testing_args_check(args)

    @staticmethod
    def parse_args():
        """Parse the JSON request body for a hit-testing call.

        Returns:
            The parsed arguments with the keys ``query``, ``retrieval_model``
            and ``external_retrieval_model``. None of them is marked as
            required at the parser level.
        """
        parser = reqparse.RequestParser()
        parser.add_argument("query", type=str, location="json")
        parser.add_argument("retrieval_model", type=dict, required=False, location="json")
        parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
        return parser.parse_args()

    @staticmethod
    def perform_hit_testing(dataset, args) -> dict:
        """Run the retrieval and translate service errors into API errors.

        Args:
            dataset: Dataset to query, as returned by
                ``get_and_validate_dataset``.
            args: Parsed request arguments; ``query``, ``retrieval_model`` and
                ``external_retrieval_model`` are read from it.

        Returns:
            A dict with the ``query`` from the service response and the
            ``records`` marshalled with ``hit_testing_record_fields``.

        Raises:
            DatasetNotInitializedError: The index is not initialized.
            ProviderNotInitializeError: The provider token is missing, or the
                service raised ``LLMBadRequestError``.
            ProviderQuotaExceededError: The provider quota is exhausted.
            ProviderModelCurrentlyNotSupportError: The model is not supported.
            CompletionRequestError: Any other ``InvokeError``.
            ValueError: Re-raised as a plain ``ValueError`` with the same
                message.
            InternalServerError: Any other exception; it is logged first.
        """
        try:
            response = HitTestingService.retrieve(
                dataset=dataset,
                query=args["query"],
                account=current_user,
                retrieval_model=args["retrieval_model"],
                external_retrieval_model=args["external_retrieval_model"],
                limit=10,
            )
            # Marshalling stays inside the try block so that a failure here is
            # also reported through the handlers below.
            return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)}
        except services.errors.index.IndexNotInitializedError as e:
            raise DatasetNotInitializedError() from e
        except ProviderTokenNotInitError as ex:
            raise ProviderNotInitializeError(ex.description) from ex
        except QuotaExceededError as e:
            raise ProviderQuotaExceededError() from e
        except ModelCurrentlyNotSupportError as e:
            raise ProviderModelCurrentlyNotSupportError() from e
        except LLMBadRequestError as e:
            raise ProviderNotInitializeError(
                "No Embedding Model or Reranking Model available. Please configure a valid provider "
                "in the Settings -> Model Provider."
            ) from e
        except InvokeError as e:
            raise CompletionRequestError(e.description) from e
        except ValueError as e:
            # Handled explicitly so it is not caught by the generic handler
            # below and turned into an InternalServerError.
            raise ValueError(str(e)) from e
        except Exception as e:
            logging.exception("Hit testing failed.")
            raise InternalServerError(str(e)) from e
| @@ -5,7 +5,6 @@ from libs.external_api import ExternalApi | |||
| bp = Blueprint("service_api", __name__, url_prefix="/v1") | |||
| api = ExternalApi(bp) | |||
| from . import index | |||
| from .app import app, audio, completion, conversation, file, message, workflow | |||
| from .dataset import dataset, document, segment | |||
| from .dataset import dataset, document, hit_testing, segment | |||
| @@ -0,0 +1,17 @@ | |||
| from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase | |||
| from controllers.service_api import api | |||
| from controllers.service_api.wraps import DatasetApiResource | |||
class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase):
    """Service API resource that exposes dataset hit testing via POST."""

    def post(self, tenant_id, dataset_id):
        """Run a hit test against the dataset identified by ``dataset_id``.

        ``tenant_id`` is part of the handler signature but is not referenced
        in this method. The dataset id is stringified, then the inherited
        helpers look up and validate the dataset, parse and check the request
        arguments, and perform the retrieval.
        """
        target_dataset = self.get_and_validate_dataset(str(dataset_id))

        parsed_args = self.parse_args()
        self.hit_testing_args_check(parsed_args)

        return self.perform_hit_testing(target_dataset, parsed_args)
| api.add_resource(HitTestingApi, "/datasets/<uuid:dataset_id>/hit-testing") | |||
| @@ -1050,6 +1050,151 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from | |||
| --- | |||
| <Heading | |||
| url='/datasets/{dataset_id}/hit-testing' | |||
| method='POST' | |||
| title='Dataset hit testing' | |||
| name='#dataset_hit_testing' | |||
| /> | |||
| <Row> | |||
| <Col> | |||
| ### Path | |||
| <Properties> | |||
| <Property name='dataset_id' type='string' key='dataset_id'> | |||
| Dataset ID | |||
| </Property> | |||
| </Properties> | |||
| ### Request Body | |||
| <Properties> | |||
| <Property name='query' type='string' key='query'> | |||
| Retrieval keyword | |||
| </Property> | |||
| <Property name='retrieval_model' type='object' key='retrieval_model'> | |||
| Retrieval parameters (optional; if omitted, the default retrieval method is used) | |||
| - <code>search_method</code> (text) Search method: One of the following four keywords is required | |||
| - <code>keyword_search</code> Keyword search | |||
| - <code>semantic_search</code> Semantic search | |||
| - <code>full_text_search</code> Full-text search | |||
| - <code>hybrid_search</code> Hybrid search | |||
| - <code>reranking_enable</code> (bool) Whether to enable reranking, optional, required if the search mode is semantic_search or hybrid_search | |||
| - <code>reranking_model</code> (object) Rerank model configuration, optional, required if reranking is enabled | |||
| - <code>reranking_provider_name</code> (string) Rerank model provider | |||
| - <code>reranking_model_name</code> (string) Rerank model name | |||
| - <code>weights</code> (double) Semantic search weight setting in hybrid search mode | |||
| - <code>top_k</code> (integer) Number of results to return, optional | |||
| - <code>score_threshold_enabled</code> (bool) Whether to enable score threshold | |||
| - <code>score_threshold</code> (double) Score threshold | |||
| </Property> | |||
| <Property name='external_retrieval_model' type='object' key='external_retrieval_model'> | |||
| Unused field | |||
| </Property> | |||
| </Properties> | |||
| </Col> | |||
| <Col sticky> | |||
| <CodeGroup | |||
| title="Request" | |||
| tag="POST" | |||
| label="/datasets/{dataset_id}/hit-testing" | |||
| targetCode={`curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/hit-testing' \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{ | |||
| "query": "test", | |||
| "retrieval_model": { | |||
| "search_method": "keyword_search", | |||
| "reranking_enable": false, | |||
| "reranking_mode": null, | |||
| "reranking_model": { | |||
| "reranking_provider_name": "", | |||
| "reranking_model_name": "" | |||
| }, | |||
| "weights": null, | |||
| "top_k": 1, | |||
| "score_threshold_enabled": false, | |||
| "score_threshold": null | |||
| } | |||
| }'`} | |||
| > | |||
| ```bash {{ title: 'cURL' }} | |||
| curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/hit-testing' \ | |||
| --header 'Authorization: Bearer {api_key}' \ | |||
| --header 'Content-Type: application/json' \ | |||
| --data-raw '{ | |||
| "query": "test", | |||
| "retrieval_model": { | |||
| "search_method": "keyword_search", | |||
| "reranking_enable": false, | |||
| "reranking_mode": null, | |||
| "reranking_model": { | |||
| "reranking_provider_name": "", | |||
| "reranking_model_name": "" | |||
| }, | |||
| "weights": null, | |||
| "top_k": 2, | |||
| "score_threshold_enabled": false, | |||
| "score_threshold": null | |||
| } | |||
| }' | |||
| ``` | |||
| </CodeGroup> | |||
| <CodeGroup title="Response"> | |||
| ```json {{ title: 'Response' }} | |||
| { | |||
| "query": { | |||
| "content": "test" | |||
| }, | |||
| "records": [ | |||
| { | |||
| "segment": { | |||
| "id": "7fa6f24f-8679-48b3-bc9d-bdf28d73f218", | |||
| "position": 1, | |||
| "document_id": "a8c6c36f-9f5d-4d7a-8472-f5d7b75d71d2", | |||
| "content": "Operation guide", | |||
| "answer": null, | |||
| "word_count": 847, | |||
| "tokens": 280, | |||
| "keywords": [ | |||
| "install", | |||
| "java", | |||
| "base", | |||
| "scripts", | |||
| "jdk", | |||
| "manual", | |||
| "internal", | |||
| "opens", | |||
| "add", | |||
| "vmoptions" | |||
| ], | |||
| "index_node_id": "39dd8443-d960-45a8-bb46-7275ad7fbc8e", | |||
| "index_node_hash": "0189157697b3c6a418ccf8264a09699f25858975578f3467c76d6bfc94df1d73", | |||
| "hit_count": 0, | |||
| "enabled": true, | |||
| "disabled_at": null, | |||
| "disabled_by": null, | |||
| "status": "completed", | |||
| "created_by": "dbcb1ab5-90c8-41a7-8b78-73b235eb6f6f", | |||
| "created_at": 1728734540, | |||
| "indexing_at": 1728734552, | |||
| "completed_at": 1728734584, | |||
| "error": null, | |||
| "stopped_at": null, | |||
| "document": { | |||
| "id": "a8c6c36f-9f5d-4d7a-8472-f5d7b75d71d2", | |||
| "data_source_type": "upload_file", | |||
| "name": "readme.txt", | |||
| "doc_type": null | |||
| } | |||
| }, | |||
| "score": 3.730463140527718e-05, | |||
| "tsne_position": null | |||
| } | |||
| ] | |||
| } | |||
| ``` | |||
| </CodeGroup> | |||
| </Col> | |||
| </Row> | |||
| --- | |||
| <Row> | |||
| <Col> | |||
| ### Error message | |||
| @@ -1049,6 +1049,152 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from | |||
| </Col> | |||
| </Row> | |||
| --- | |||
| <Heading | |||
| url='/datasets/{dataset_id}/hit-testing' | |||
| method='POST' | |||
| title='知识库召回测试' | |||
| name='#dataset_hit_testing' | |||
| /> | |||
| <Row> | |||
| <Col> | |||
| ### Path | |||
| <Properties> | |||
| <Property name='dataset_id' type='string' key='dataset_id'> | |||
| 知识库 ID | |||
| </Property> | |||
| </Properties> | |||
| ### Request Body | |||
| <Properties> | |||
| <Property name='query' type='string' key='query'> | |||
| 召回关键词 | |||
| </Property> | |||
| <Property name='retrieval_model' type='object' key='retrieval_model'> | |||
| 召回参数(选填,如不填,按照默认方式召回) | |||
| - <code>search_method</code> (text) 检索方法:以下四个关键字之一,必填 | |||
| - <code>keyword_search</code> 关键字检索 | |||
| - <code>semantic_search</code> 语义检索 | |||
| - <code>full_text_search</code> 全文检索 | |||
| - <code>hybrid_search</code> 混合检索 | |||
| - <code>reranking_enable</code> (bool) 是否启用 Reranking,非必填,如果检索模式为semantic_search模式或者hybrid_search则传值 | |||
| - <code>reranking_model</code> (object) Rerank模型配置,非必填,如果启用了 reranking 则传值 | |||
| - <code>reranking_provider_name</code> (string) Rerank 模型提供商 | |||
| - <code>reranking_model_name</code> (string) Rerank 模型名称 | |||
| - <code>weights</code> (double) 混合检索模式下语义检索的权重设置 | |||
| - <code>top_k</code> (integer) 返回结果数量,非必填 | |||
| - <code>score_threshold_enabled</code> (bool) 是否开启Score阈值 | |||
| - <code>score_threshold</code> (double) Score阈值 | |||
| </Property> | |||
| <Property name='external_retrieval_model' type='object' key='external_retrieval_model'> | |||
| 未启用字段 | |||
| </Property> | |||
| </Properties> | |||
| </Col> | |||
| <Col sticky> | |||
| <CodeGroup | |||
| title="Request" | |||
| tag="POST" | |||
| label="/datasets/{dataset_id}/hit-testing" | |||
| targetCode={`curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/hit-testing' \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{ | |||
| "query": "test", | |||
| "retrieval_model": { | |||
| "search_method": "keyword_search", | |||
| "reranking_enable": false, | |||
| "reranking_mode": null, | |||
| "reranking_model": { | |||
| "reranking_provider_name": "", | |||
| "reranking_model_name": "" | |||
| }, | |||
| "weights": null, | |||
| "top_k": 1, | |||
| "score_threshold_enabled": false, | |||
| "score_threshold": null | |||
| } | |||
| }'`} | |||
| > | |||
| ```bash {{ title: 'cURL' }} | |||
| curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/hit-testing' \ | |||
| --header 'Authorization: Bearer {api_key}' \ | |||
| --header 'Content-Type: application/json' \ | |||
| --data-raw '{ | |||
| "query": "test", | |||
| "retrieval_model": { | |||
| "search_method": "keyword_search", | |||
| "reranking_enable": false, | |||
| "reranking_mode": null, | |||
| "reranking_model": { | |||
| "reranking_provider_name": "", | |||
| "reranking_model_name": "" | |||
| }, | |||
| "weights": null, | |||
| "top_k": 2, | |||
| "score_threshold_enabled": false, | |||
| "score_threshold": null | |||
| } | |||
| }' | |||
| ``` | |||
| </CodeGroup> | |||
| <CodeGroup title="Response"> | |||
| ```json {{ title: 'Response' }} | |||
| { | |||
| "query": { | |||
| "content": "test" | |||
| }, | |||
| "records": [ | |||
| { | |||
| "segment": { | |||
| "id": "7fa6f24f-8679-48b3-bc9d-bdf28d73f218", | |||
| "position": 1, | |||
| "document_id": "a8c6c36f-9f5d-4d7a-8472-f5d7b75d71d2", | |||
| "content": "Operation guide", | |||
| "answer": null, | |||
| "word_count": 847, | |||
| "tokens": 280, | |||
| "keywords": [ | |||
| "install", | |||
| "java", | |||
| "base", | |||
| "scripts", | |||
| "jdk", | |||
| "manual", | |||
| "internal", | |||
| "opens", | |||
| "add", | |||
| "vmoptions" | |||
| ], | |||
| "index_node_id": "39dd8443-d960-45a8-bb46-7275ad7fbc8e", | |||
| "index_node_hash": "0189157697b3c6a418ccf8264a09699f25858975578f3467c76d6bfc94df1d73", | |||
| "hit_count": 0, | |||
| "enabled": true, | |||
| "disabled_at": null, | |||
| "disabled_by": null, | |||
| "status": "completed", | |||
| "created_by": "dbcb1ab5-90c8-41a7-8b78-73b235eb6f6f", | |||
| "created_at": 1728734540, | |||
| "indexing_at": 1728734552, | |||
| "completed_at": 1728734584, | |||
| "error": null, | |||
| "stopped_at": null, | |||
| "document": { | |||
| "id": "a8c6c36f-9f5d-4d7a-8472-f5d7b75d71d2", | |||
| "data_source_type": "upload_file", | |||
| "name": "readme.txt", | |||
| "doc_type": null | |||
| } | |||
| }, | |||
| "score": 3.730463140527718e-05, | |||
| "tsne_position": null | |||
| } | |||
| ] | |||
| } | |||
| ``` | |||
| </CodeGroup> | |||
| </Col> | |||
| </Row> | |||
| --- | |||
| <Row> | |||