You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

rag_pipeline.py 4.1KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. import logging
  2. from flask import request
  3. from flask_restful import Resource, reqparse
  4. from sqlalchemy.orm import Session
  5. from controllers.console import api
  6. from controllers.console.wraps import (
  7. account_initialization_required,
  8. enterprise_license_required,
  9. setup_required,
  10. )
  11. from extensions.ext_database import db
  12. from libs.login import login_required
  13. from models.dataset import Pipeline, PipelineCustomizedTemplate
  14. from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity
  15. from services.rag_pipeline.rag_pipeline import RagPipelineService
  16. from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
  17. logger = logging.getLogger(__name__)
  18. def _validate_name(name):
  19. if not name or len(name) < 1 or len(name) > 40:
  20. raise ValueError("Name must be between 1 to 40 characters.")
  21. return name
  22. def _validate_description_length(description):
  23. if len(description) > 400:
  24. raise ValueError("Description cannot exceed 400 characters.")
  25. return description
  26. class PipelineTemplateListApi(Resource):
  27. @setup_required
  28. @login_required
  29. @account_initialization_required
  30. @enterprise_license_required
  31. def get(self):
  32. type = request.args.get("type", default="built-in", type=str)
  33. language = request.args.get("language", default="en-US", type=str)
  34. # get pipeline templates
  35. pipeline_templates = RagPipelineService.get_pipeline_templates(type, language)
  36. return pipeline_templates, 200
  37. class PipelineTemplateDetailApi(Resource):
  38. @setup_required
  39. @login_required
  40. @account_initialization_required
  41. @enterprise_license_required
  42. def get(self, template_id: str):
  43. pipeline_template = RagPipelineService.get_pipeline_template_detail(template_id)
  44. return pipeline_template, 200
  45. class CustomizedPipelineTemplateApi(Resource):
  46. @setup_required
  47. @login_required
  48. @account_initialization_required
  49. @enterprise_license_required
  50. def patch(self, template_id: str):
  51. parser = reqparse.RequestParser()
  52. parser.add_argument(
  53. "name",
  54. nullable=False,
  55. required=True,
  56. help="Name must be between 1 to 40 characters.",
  57. type=_validate_name,
  58. )
  59. parser.add_argument(
  60. "description",
  61. type=str,
  62. nullable=True,
  63. required=False,
  64. default="",
  65. )
  66. parser.add_argument(
  67. "icon_info",
  68. type=dict,
  69. location="json",
  70. nullable=True,
  71. )
  72. args = parser.parse_args()
  73. pipeline_template_info = PipelineTemplateInfoEntity(**args)
  74. pipeline_template = RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info)
  75. return pipeline_template, 200
  76. @setup_required
  77. @login_required
  78. @account_initialization_required
  79. @enterprise_license_required
  80. def delete(self, template_id: str):
  81. RagPipelineService.delete_customized_pipeline_template(template_id)
  82. return 200
  83. @setup_required
  84. @login_required
  85. @account_initialization_required
  86. @enterprise_license_required
  87. def post(self, template_id: str):
  88. with Session(db.engine) as session:
  89. template = (
  90. session.query(PipelineCustomizedTemplate).filter(PipelineCustomizedTemplate.id == template_id).first()
  91. )
  92. if not template:
  93. raise ValueError("Customized pipeline template not found.")
  94. pipeline = session.query(Pipeline).filter(Pipeline.id == template.pipeline_id).first()
  95. if not pipeline:
  96. raise ValueError("Pipeline not found.")
  97. dsl = RagPipelineDslService.export_rag_pipeline_dsl(pipeline, include_secret=True)
  98. return {"data": dsl}, 200
  99. api.add_resource(
  100. PipelineTemplateListApi,
  101. "/rag/pipeline/templates",
  102. )
  103. api.add_resource(
  104. PipelineTemplateDetailApi,
  105. "/rag/pipeline/templates/<string:template_id>",
  106. )
  107. api.add_resource(
  108. CustomizedPipelineTemplateApi,
  109. "/rag/pipeline/customized/templates/<string:template_id>",
  110. )