You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

database_retrieval.py 2.8KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. from typing import Optional
  2. import yaml
  3. from extensions.ext_database import db
  4. from models.dataset import PipelineBuiltInTemplate
  5. from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase
  6. from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType
  7. class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
  8. """
  9. Retrieval pipeline template from database
  10. """
  11. def get_pipeline_templates(self, language: str) -> dict:
  12. result = self.fetch_pipeline_templates_from_db(language)
  13. return result
  14. def get_pipeline_template_detail(self, pipeline_id: str):
  15. result = self.fetch_pipeline_template_detail_from_db(pipeline_id)
  16. return result
  17. def get_type(self) -> str:
  18. return PipelineTemplateType.DATABASE
  19. @classmethod
  20. def fetch_pipeline_templates_from_db(cls, language: str) -> dict:
  21. """
  22. Fetch pipeline templates from db.
  23. :param language: language
  24. :return:
  25. """
  26. pipeline_built_in_templates: list[PipelineBuiltInTemplate] = (
  27. db.session.query(PipelineBuiltInTemplate).filter(PipelineBuiltInTemplate.language == language).all()
  28. )
  29. recommended_pipelines_results = []
  30. for pipeline_built_in_template in pipeline_built_in_templates:
  31. recommended_pipeline_result = {
  32. "id": pipeline_built_in_template.id,
  33. "name": pipeline_built_in_template.name,
  34. "description": pipeline_built_in_template.description,
  35. "icon": pipeline_built_in_template.icon,
  36. "copyright": pipeline_built_in_template.copyright,
  37. "privacy_policy": pipeline_built_in_template.privacy_policy,
  38. "position": pipeline_built_in_template.position,
  39. "chunk_structure": pipeline_built_in_template.chunk_structure,
  40. }
  41. recommended_pipelines_results.append(recommended_pipeline_result)
  42. return {"pipeline_templates": recommended_pipelines_results}
  43. @classmethod
  44. def fetch_pipeline_template_detail_from_db(cls, pipeline_id: str) -> Optional[dict]:
  45. """
  46. Fetch pipeline template detail from db.
  47. :param pipeline_id: Pipeline ID
  48. :return:
  49. """
  50. # is in public recommended list
  51. pipeline_template = (
  52. db.session.query(PipelineBuiltInTemplate).filter(PipelineBuiltInTemplate.id == pipeline_id).first()
  53. )
  54. if not pipeline_template:
  55. return None
  56. return {
  57. "id": pipeline_template.id,
  58. "name": pipeline_template.name,
  59. "icon": pipeline_template.icon,
  60. "chunk_structure": pipeline_template.chunk_structure,
  61. "export_data": yaml.safe_load(pipeline_template.yaml_content),
  62. }