您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

hit_testing.py 3.2KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. import logging
  2. from flask_login import login_required, current_user
  3. from flask_restful import Resource, reqparse, marshal, fields
  4. from werkzeug.exceptions import InternalServerError, NotFound, Forbidden
  5. import services
  6. from controllers.console import api
  7. from controllers.console.datasets.error import HighQualityDatasetOnlyError, DatasetNotInitializedError
  8. from controllers.console.setup import setup_required
  9. from controllers.console.wraps import account_initialization_required
  10. from libs.helper import TimestampField
  11. from services.dataset_service import DatasetService
  12. from services.hit_testing_service import HitTestingService
  13. document_fields = {
  14. 'id': fields.String,
  15. 'data_source_type': fields.String,
  16. 'name': fields.String,
  17. 'doc_type': fields.String,
  18. }
  19. segment_fields = {
  20. 'id': fields.String,
  21. 'position': fields.Integer,
  22. 'document_id': fields.String,
  23. 'content': fields.String,
  24. 'word_count': fields.Integer,
  25. 'tokens': fields.Integer,
  26. 'keywords': fields.List(fields.String),
  27. 'index_node_id': fields.String,
  28. 'index_node_hash': fields.String,
  29. 'hit_count': fields.Integer,
  30. 'enabled': fields.Boolean,
  31. 'disabled_at': TimestampField,
  32. 'disabled_by': fields.String,
  33. 'status': fields.String,
  34. 'created_by': fields.String,
  35. 'created_at': TimestampField,
  36. 'indexing_at': TimestampField,
  37. 'completed_at': TimestampField,
  38. 'error': fields.String,
  39. 'stopped_at': TimestampField,
  40. 'document': fields.Nested(document_fields),
  41. }
  42. hit_testing_record_fields = {
  43. 'segment': fields.Nested(segment_fields),
  44. 'score': fields.Float,
  45. 'tsne_position': fields.Raw
  46. }
  47. class HitTestingApi(Resource):
  48. @setup_required
  49. @login_required
  50. @account_initialization_required
  51. def post(self, dataset_id):
  52. dataset_id_str = str(dataset_id)
  53. dataset = DatasetService.get_dataset(dataset_id_str)
  54. if dataset is None:
  55. raise NotFound("Dataset not found.")
  56. try:
  57. DatasetService.check_dataset_permission(dataset, current_user)
  58. except services.errors.account.NoPermissionError as e:
  59. raise Forbidden(str(e))
  60. # only high quality dataset can be used for hit testing
  61. if dataset.indexing_technique != 'high_quality':
  62. raise HighQualityDatasetOnlyError()
  63. parser = reqparse.RequestParser()
  64. parser.add_argument('query', type=str, location='json')
  65. args = parser.parse_args()
  66. query = args['query']
  67. if not query or len(query) > 250:
  68. raise ValueError('Query is required and cannot exceed 250 characters')
  69. try:
  70. response = HitTestingService.retrieve(
  71. dataset=dataset,
  72. query=query,
  73. account=current_user,
  74. limit=10,
  75. )
  76. return {"query": response['query'], 'records': marshal(response['records'], hit_testing_record_fields)}
  77. except services.errors.index.IndexNotInitializedError:
  78. raise DatasetNotInitializedError()
  79. except Exception as e:
  80. logging.exception("Hit testing failed.")
  81. raise InternalServerError(str(e))
  82. api.add_resource(HitTestingApi, '/datasets/<uuid:dataset_id>/hit-testing')