您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

customized_retrieval.py 2.9KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. from typing import Optional
  2. from flask_login import current_user
  3. import yaml
  4. from extensions.ext_database import db
  5. from models.dataset import PipelineCustomizedTemplate
  6. from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase
  7. from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType
  8. from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
  class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
      """
      Retrieve customized pipeline templates from the database.

      Templates are stored as ``PipelineCustomizedTemplate`` rows that carry a
      ``tenant_id``; the instance methods scope every lookup to the current
      user's tenant.
      """

      def get_pipeline_templates(self, language: str) -> dict:
          """
          List the current tenant's customized pipeline templates for a language.
          :param language: language the templates are stored under
          :return: ``{"pipeline_templates": [...]}`` with one summary dict per template
          """
          return self.fetch_pipeline_templates_from_customized(
              tenant_id=current_user.current_tenant_id, language=language
          )

      def get_pipeline_template_detail(self, template_id: str) -> Optional[dict]:
          """
          Fetch one customized pipeline template owned by the current tenant.
          :param template_id: template id
          :return: template detail dict, or None if the current tenant has no
              template with this id
          """
          # Scope the detail lookup to the caller's tenant, matching the list lookup,
          # so a template id belonging to another tenant cannot be used to read its DSL.
          return self.fetch_pipeline_template_detail_from_db(
              template_id, tenant_id=current_user.current_tenant_id
          )

      def get_type(self) -> str:
          """Return the template type served by this retrieval (CUSTOMIZED)."""
          return PipelineTemplateType.CUSTOMIZED

      @classmethod
      def fetch_pipeline_templates_from_customized(cls, tenant_id: str, language: str) -> dict:
          """
          Fetch a tenant's customized pipeline templates for a language from db.
          :param tenant_id: tenant id
          :param language: language
          :return: ``{"pipeline_templates": [...]}``; each entry holds the template's
              id, name, description, icon, position and chunk_structure
          """
          pipeline_customized_templates = (
              db.session.query(PipelineCustomizedTemplate)
              .filter(PipelineCustomizedTemplate.tenant_id == tenant_id, PipelineCustomizedTemplate.language == language)
              .all()
          )
          recommended_pipelines_results = [
              {
                  "id": template.id,
                  "name": template.name,
                  "description": template.description,
                  "icon": template.icon,
                  "position": template.position,
                  "chunk_structure": template.chunk_structure,
              }
              for template in pipeline_customized_templates
          ]
          return {"pipeline_templates": recommended_pipelines_results}

      @classmethod
      def fetch_pipeline_template_detail_from_db(
          cls, template_id: str, tenant_id: Optional[str] = None
      ) -> Optional[dict]:
          """
          Fetch pipeline template detail from db.
          :param template_id: Template ID
          :param tenant_id: when given, only a template owned by this tenant is
              returned; when None (the default) the lookup is by id alone
          :return: dict with id, name, icon and the parsed YAML ``export_data``,
              or None if no matching template exists
          """
          query = db.session.query(PipelineCustomizedTemplate).filter(PipelineCustomizedTemplate.id == template_id)
          if tenant_id is not None:
              query = query.filter(PipelineCustomizedTemplate.tenant_id == tenant_id)
          pipeline_template = query.first()
          if not pipeline_template:
              return None
          return {
              "id": pipeline_template.id,
              "name": pipeline_template.name,
              "icon": pipeline_template.icon,
              # safe_load only constructs plain Python objects; keep it (not yaml.load),
              # since yaml_content is presumably user-authored DSL — confirm upstream.
              "export_data": yaml.safe_load(pipeline_template.yaml_content),
          }