You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

rag_pipeline.py 4.7KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. import logging
  2. import yaml
  3. from flask import request
  4. from flask_restful import Resource, reqparse
  5. from sqlalchemy.orm import Session
  6. from controllers.console import api
  7. from controllers.console.wraps import (
  8. account_initialization_required,
  9. enterprise_license_required,
  10. setup_required,
  11. )
  12. from extensions.ext_database import db
  13. from libs.login import login_required
  14. from models.dataset import PipelineCustomizedTemplate
  15. from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity
  16. from services.rag_pipeline.rag_pipeline import RagPipelineService
# Module-level logger named after this module; none of the handlers in this
# file currently log through it.
logger = logging.getLogger(__name__)
  18. def _validate_name(name):
  19. if not name or len(name) < 1 or len(name) > 40:
  20. raise ValueError("Name must be between 1 to 40 characters.")
  21. return name
  22. def _validate_description_length(description):
  23. if len(description) > 400:
  24. raise ValueError("Description cannot exceed 400 characters.")
  25. return description
  26. class PipelineTemplateListApi(Resource):
  27. @setup_required
  28. @login_required
  29. @account_initialization_required
  30. @enterprise_license_required
  31. def get(self):
  32. type = request.args.get("type", default="built-in", type=str)
  33. language = request.args.get("language", default="en-US", type=str)
  34. # get pipeline templates
  35. pipeline_templates = RagPipelineService.get_pipeline_templates(type, language)
  36. return pipeline_templates, 200
  37. class PipelineTemplateDetailApi(Resource):
  38. @setup_required
  39. @login_required
  40. @account_initialization_required
  41. @enterprise_license_required
  42. def get(self, template_id: str):
  43. pipeline_template = RagPipelineService.get_pipeline_template_detail(template_id)
  44. return pipeline_template, 200
  45. class CustomizedPipelineTemplateApi(Resource):
  46. @setup_required
  47. @login_required
  48. @account_initialization_required
  49. @enterprise_license_required
  50. def patch(self, template_id: str):
  51. parser = reqparse.RequestParser()
  52. parser.add_argument(
  53. "name",
  54. nullable=False,
  55. required=True,
  56. help="Name must be between 1 to 40 characters.",
  57. type=_validate_name,
  58. )
  59. parser.add_argument(
  60. "description",
  61. type=str,
  62. nullable=True,
  63. required=False,
  64. default="",
  65. )
  66. parser.add_argument(
  67. "icon_info",
  68. type=dict,
  69. location="json",
  70. nullable=True,
  71. )
  72. args = parser.parse_args()
  73. pipeline_template_info = PipelineTemplateInfoEntity(**args)
  74. RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info)
  75. return 200
  76. @setup_required
  77. @login_required
  78. @account_initialization_required
  79. @enterprise_license_required
  80. def delete(self, template_id: str):
  81. RagPipelineService.delete_customized_pipeline_template(template_id)
  82. return 200
  83. @setup_required
  84. @login_required
  85. @account_initialization_required
  86. @enterprise_license_required
  87. def post(self, template_id: str):
  88. with Session(db.engine) as session:
  89. template = (
  90. session.query(PipelineCustomizedTemplate).filter(PipelineCustomizedTemplate.id == template_id).first()
  91. )
  92. if not template:
  93. raise ValueError("Customized pipeline template not found.")
  94. dsl = yaml.safe_load(template.yaml_content)
  95. return {"data": dsl}, 200
  96. class CustomizedPipelineTemplateApi(Resource):
  97. @setup_required
  98. @login_required
  99. @account_initialization_required
  100. @enterprise_license_required
  101. def post(self, pipeline_id: str):
  102. parser = reqparse.RequestParser()
  103. parser.add_argument(
  104. "name",
  105. nullable=False,
  106. required=True,
  107. help="Name must be between 1 to 40 characters.",
  108. type=_validate_name,
  109. )
  110. parser.add_argument(
  111. "description",
  112. type=str,
  113. nullable=True,
  114. required=False,
  115. default="",
  116. )
  117. parser.add_argument(
  118. "icon_info",
  119. type=dict,
  120. location="json",
  121. nullable=True,
  122. )
  123. args = parser.parse_args()
  124. rag_pipeline_service = RagPipelineService()
  125. RagPipelineService.publish_customized_pipeline_template(pipeline_id, args)
  126. return 200
  127. api.add_resource(
  128. PipelineTemplateListApi,
  129. "/rag/pipeline/templates",
  130. )
  131. api.add_resource(
  132. PipelineTemplateDetailApi,
  133. "/rag/pipeline/templates/<string:template_id>",
  134. )
  135. api.add_resource(
  136. CustomizedPipelineTemplateApi,
  137. "/rag/pipeline/customized/templates/<string:template_id>",
  138. )