You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

rag_pipeline.py 4.8KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. import logging
  2. from flask import request
  3. from flask_restx import Resource, reqparse
  4. from sqlalchemy.orm import Session
  5. from controllers.console import console_ns
  6. from controllers.console.wraps import (
  7. account_initialization_required,
  8. enterprise_license_required,
  9. knowledge_pipeline_publish_enabled,
  10. setup_required,
  11. )
  12. from extensions.ext_database import db
  13. from libs.login import login_required
  14. from models.dataset import PipelineCustomizedTemplate
  15. from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity
  16. from services.rag_pipeline.rag_pipeline import RagPipelineService
  # Logger for this module, registered under the module's import name.
  logger = logging.getLogger(name=__name__)
  def _validate_name(name):
      """Validate a pipeline-template name; used as a ``reqparse`` ``type`` callable.

      Args:
          name: Raw value taken from the request.

      Returns:
          ``name`` unchanged when it is a string of 1 to 40 characters.

      Raises:
          ValueError: If ``name`` is not a string, is empty, or is longer
              than 40 characters.
      """
      # Reject non-strings explicitly: a JSON list/dict such as ``["a"]`` has a
      # valid len() and would otherwise be accepted and returned as the name.
      if not isinstance(name, str) or not 1 <= len(name) <= 40:
          raise ValueError("Name must be between 1 to 40 characters.")
      return name
  def _validate_description_length(description):
      """Check that a description is at most 400 characters long.

      Args:
          description: Value to check; it must support ``len()``.

      Returns:
          ``description`` unchanged when it is short enough.

      Raises:
          ValueError: If ``description`` has more than 400 characters.
      """
      if len(description) <= 400:
          return description
      raise ValueError("Description cannot exceed 400 characters.")
  @console_ns.route("/rag/pipeline/templates")
  class PipelineTemplateListApi(Resource):
      """List pipeline templates for a given template type and language."""

      @setup_required
      @login_required
      @account_initialization_required
      @enterprise_license_required
      def get(self):
          """Return pipeline templates selected by query parameters.

          Query parameters:
              type: Template source; defaults to ``"built-in"``.
              language: Template language; defaults to ``"en-US"``.

          Returns:
              ``(templates, 200)`` where ``templates`` is the value returned by
              ``RagPipelineService.get_pipeline_templates``.
          """
          # Named ``template_type`` rather than ``type`` so the builtin is not shadowed.
          template_type = request.args.get("type", default="built-in", type=str)
          language = request.args.get("language", default="en-US", type=str)
          pipeline_templates = RagPipelineService.get_pipeline_templates(template_type, language)
          return pipeline_templates, 200
  @console_ns.route("/rag/pipeline/templates/<string:template_id>")
  class PipelineTemplateDetailApi(Resource):
      """Fetch the detail of a single pipeline template."""

      @setup_required
      @login_required
      @account_initialization_required
      @enterprise_license_required
      def get(self, template_id: str):
          """Return the detail of the template identified by ``template_id``.

          Query parameters:
              type: Template source; defaults to ``"built-in"``.

          Returns:
              ``(detail, 200)`` where ``detail`` is the value returned by
              ``RagPipelineService.get_pipeline_template_detail``.
          """
          # Named ``template_type`` rather than ``type`` so the builtin is not shadowed.
          template_type = request.args.get("type", default="built-in", type=str)
          rag_pipeline_service = RagPipelineService()
          pipeline_template = rag_pipeline_service.get_pipeline_template_detail(template_id, template_type)
          return pipeline_template, 200
  def _build_pipeline_template_info_parser():
      """Build the request parser shared by the template update and publish endpoints.

      Fields:
          name: Required and non-null; validated by ``_validate_name``.
          description: Optional, nullable string; defaults to ``""``.
          icon_info: Optional, nullable dict read from the JSON body.
      """
      parser = reqparse.RequestParser()
      parser.add_argument(
          "name",
          nullable=False,
          required=True,
          help="Name must be between 1 to 40 characters.",
          type=_validate_name,
      )
      # NOTE(review): description length is not checked here, although this
      # module defines `_validate_description_length` (400-char limit).
      # Confirm whether that limit should apply to templates.
      parser.add_argument(
          "description",
          type=str,
          nullable=True,
          required=False,
          default="",
      )
      parser.add_argument(
          "icon_info",
          type=dict,
          location="json",
          nullable=True,
      )
      return parser


  @console_ns.route("/rag/pipeline/customized/templates/<string:template_id>")
  class CustomizedPipelineTemplateApi(Resource):
      """Update, delete, or export one customized pipeline template."""

      @setup_required
      @login_required
      @account_initialization_required
      @enterprise_license_required
      def patch(self, template_id: str):
          """Update a customized template's name, description, and icon info.

          Returns:
              ``200``.
          """
          args = _build_pipeline_template_info_parser().parse_args()
          pipeline_template_info = PipelineTemplateInfoEntity(**args)
          RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info)
          return 200

      @setup_required
      @login_required
      @account_initialization_required
      @enterprise_license_required
      def delete(self, template_id: str):
          """Delete the customized template identified by ``template_id``.

          Returns:
              ``200``.
          """
          RagPipelineService.delete_customized_pipeline_template(template_id)
          return 200

      @setup_required
      @login_required
      @account_initialization_required
      @enterprise_license_required
      def post(self, template_id: str):
          """Return the stored YAML content of a customized template.

          Returns:
              ``({"data": yaml_content}, 200)``.

          Raises:
              ValueError: If no template with ``template_id`` exists.
          """
          # NOTE(review): the lookup filters only by id, not by the caller's
          # tenant/workspace, so any caller passing the decorators above who
          # knows an id can read its YAML. Confirm whether tenant scoping is needed.
          with Session(db.engine) as session:
              template = (
                  session.query(PipelineCustomizedTemplate)
                  .where(PipelineCustomizedTemplate.id == template_id)
                  .first()
              )
              if not template:
                  raise ValueError("Customized pipeline template not found.")
              # Read the column while the session is still open, so the result
              # does not depend on the instance's state after the session closes.
              yaml_content = template.yaml_content
          return {"data": yaml_content}, 200


  @console_ns.route("/rag/pipelines/<string:pipeline_id>/customized/publish")
  class PublishCustomizedPipelineTemplateApi(Resource):
      """Publish an existing pipeline as a customized template."""

      @setup_required
      @login_required
      @account_initialization_required
      @enterprise_license_required
      @knowledge_pipeline_publish_enabled
      def post(self, pipeline_id: str):
          """Publish ``pipeline_id`` as a customized template.

          Request body fields are ``name``, ``description`` and ``icon_info``,
          parsed the same way as for the template update endpoint.

          Returns:
              ``{"result": "success"}``.
          """
          args = _build_pipeline_template_info_parser().parse_args()
          rag_pipeline_service = RagPipelineService()
          rag_pipeline_service.publish_customized_pipeline_template(pipeline_id, args)
          return {"result": "success"}