You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

rag_pipeline.py 3.2KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. import logging
  2. from flask import request
  3. from flask_restful import Resource, reqparse # type: ignore # type: ignore
  4. from controllers.console import api
  5. from controllers.console.wraps import (
  6. account_initialization_required,
  7. enterprise_license_required,
  8. setup_required,
  9. )
  10. from libs.login import login_required
  11. from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity
  12. from services.rag_pipeline.rag_pipeline import RagPipelineService
  13. logger = logging.getLogger(__name__)
  14. def _validate_name(name):
  15. if not name or len(name) < 1 or len(name) > 40:
  16. raise ValueError("Name must be between 1 to 40 characters.")
  17. return name
  18. def _validate_description_length(description):
  19. if len(description) > 400:
  20. raise ValueError("Description cannot exceed 400 characters.")
  21. return description
  22. class PipelineTemplateListApi(Resource):
  23. @setup_required
  24. @login_required
  25. @account_initialization_required
  26. @enterprise_license_required
  27. def get(self):
  28. type = request.args.get("type", default="built-in", type=str, choices=["built-in", "customized"])
  29. language = request.args.get("language", default="en-US", type=str)
  30. # get pipeline templates
  31. pipeline_templates = RagPipelineService.get_pipeline_templates(type, language)
  32. return pipeline_templates, 200
  33. class PipelineTemplateDetailApi(Resource):
  34. @setup_required
  35. @login_required
  36. @account_initialization_required
  37. @enterprise_license_required
  38. def get(self, pipeline_id: str):
  39. pipeline_template = RagPipelineService.get_pipeline_template_detail(pipeline_id)
  40. return pipeline_template, 200
  41. class CustomizedPipelineTemplateApi(Resource):
  42. @setup_required
  43. @login_required
  44. @account_initialization_required
  45. @enterprise_license_required
  46. def patch(self, template_id: str):
  47. parser = reqparse.RequestParser()
  48. parser.add_argument(
  49. "name",
  50. nullable=False,
  51. required=True,
  52. help="Name must be between 1 to 40 characters.",
  53. type=_validate_name,
  54. )
  55. parser.add_argument(
  56. "description",
  57. type=str,
  58. nullable=True,
  59. required=False,
  60. default="",
  61. )
  62. parser.add_argument(
  63. "icon_info",
  64. type=dict,
  65. location="json",
  66. nullable=True,
  67. )
  68. args = parser.parse_args()
  69. pipeline_template_info = PipelineTemplateInfoEntity(**args)
  70. pipeline_template = RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info)
  71. return pipeline_template, 200
  72. @setup_required
  73. @login_required
  74. @account_initialization_required
  75. @enterprise_license_required
  76. def delete(self, template_id: str):
  77. RagPipelineService.delete_customized_pipeline_template(template_id)
  78. return 200
  79. api.add_resource(
  80. PipelineTemplateListApi,
  81. "/rag/pipeline/templates",
  82. )
  83. api.add_resource(
  84. PipelineTemplateDetailApi,
  85. "/rag/pipeline/templates/<string:pipeline_id>",
  86. )
  87. api.add_resource(
  88. CustomizedPipelineTemplateApi,
  89. "/rag/pipeline/templates/<string:template_id>",
  90. )