You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

rag_pipeline.py 5.0KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. import logging
  2. from flask import request
  3. from flask_restful import Resource, reqparse
  4. from sqlalchemy.orm import Session
  5. from controllers.console import api
  6. from controllers.console.wraps import (
  7. account_initialization_required,
  8. enterprise_license_required,
  9. knowledge_pipeline_publish_enabled,
  10. setup_required,
  11. )
  12. from extensions.ext_database import db
  13. from libs.login import login_required
  14. from models.dataset import PipelineCustomizedTemplate
  15. from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity
  16. from services.rag_pipeline.rag_pipeline import RagPipelineService
  17. logger = logging.getLogger(__name__)
  18. def _validate_name(name):
  19. if not name or len(name) < 1 or len(name) > 40:
  20. raise ValueError("Name must be between 1 to 40 characters.")
  21. return name
  22. def _validate_description_length(description):
  23. if len(description) > 400:
  24. raise ValueError("Description cannot exceed 400 characters.")
  25. return description
  26. class PipelineTemplateListApi(Resource):
  27. @setup_required
  28. @login_required
  29. @account_initialization_required
  30. @enterprise_license_required
  31. def get(self):
  32. type = request.args.get("type", default="built-in", type=str)
  33. language = request.args.get("language", default="en-US", type=str)
  34. # get pipeline templates
  35. pipeline_templates = RagPipelineService.get_pipeline_templates(type, language)
  36. return pipeline_templates, 200
  37. class PipelineTemplateDetailApi(Resource):
  38. @setup_required
  39. @login_required
  40. @account_initialization_required
  41. @enterprise_license_required
  42. def get(self, template_id: str):
  43. type = request.args.get("type", default="built-in", type=str)
  44. rag_pipeline_service = RagPipelineService()
  45. pipeline_template = rag_pipeline_service.get_pipeline_template_detail(template_id, type)
  46. return pipeline_template, 200
  47. class CustomizedPipelineTemplateApi(Resource):
  48. @setup_required
  49. @login_required
  50. @account_initialization_required
  51. @enterprise_license_required
  52. def patch(self, template_id: str):
  53. parser = reqparse.RequestParser()
  54. parser.add_argument(
  55. "name",
  56. nullable=False,
  57. required=True,
  58. help="Name must be between 1 to 40 characters.",
  59. type=_validate_name,
  60. )
  61. parser.add_argument(
  62. "description",
  63. type=str,
  64. nullable=True,
  65. required=False,
  66. default="",
  67. )
  68. parser.add_argument(
  69. "icon_info",
  70. type=dict,
  71. location="json",
  72. nullable=True,
  73. )
  74. args = parser.parse_args()
  75. pipeline_template_info = PipelineTemplateInfoEntity(**args)
  76. RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info)
  77. return 200
  78. @setup_required
  79. @login_required
  80. @account_initialization_required
  81. @enterprise_license_required
  82. def delete(self, template_id: str):
  83. RagPipelineService.delete_customized_pipeline_template(template_id)
  84. return 200
  85. @setup_required
  86. @login_required
  87. @account_initialization_required
  88. @enterprise_license_required
  89. def post(self, template_id: str):
  90. with Session(db.engine) as session:
  91. template = (
  92. session.query(PipelineCustomizedTemplate).filter(PipelineCustomizedTemplate.id == template_id).first()
  93. )
  94. if not template:
  95. raise ValueError("Customized pipeline template not found.")
  96. return {"data": template.yaml_content}, 200
  97. class PublishCustomizedPipelineTemplateApi(Resource):
  98. @setup_required
  99. @login_required
  100. @account_initialization_required
  101. @enterprise_license_required
  102. @knowledge_pipeline_publish_enabled
  103. def post(self, pipeline_id: str):
  104. parser = reqparse.RequestParser()
  105. parser.add_argument(
  106. "name",
  107. nullable=False,
  108. required=True,
  109. help="Name must be between 1 to 40 characters.",
  110. type=_validate_name,
  111. )
  112. parser.add_argument(
  113. "description",
  114. type=str,
  115. nullable=True,
  116. required=False,
  117. default="",
  118. )
  119. parser.add_argument(
  120. "icon_info",
  121. type=dict,
  122. location="json",
  123. nullable=True,
  124. )
  125. args = parser.parse_args()
  126. rag_pipeline_service = RagPipelineService()
  127. rag_pipeline_service.publish_customized_pipeline_template(pipeline_id, args)
  128. return {"result": "success"}
  129. api.add_resource(
  130. PipelineTemplateListApi,
  131. "/rag/pipeline/templates",
  132. )
  133. api.add_resource(
  134. PipelineTemplateDetailApi,
  135. "/rag/pipeline/templates/<string:template_id>",
  136. )
  137. api.add_resource(
  138. CustomizedPipelineTemplateApi,
  139. "/rag/pipeline/customized/templates/<string:template_id>",
  140. )
  141. api.add_resource(
  142. PublishCustomizedPipelineTemplateApi,
  143. "/rag/pipelines/<string:pipeline_id>/customized/publish",
  144. )