You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

database_retrieval.py 3.3KB

6 months ago
5 months ago
5 months ago
6 months ago
5 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
5 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
5 months ago
5 months ago
5 months ago
5 months ago
5 months ago
5 months ago
5 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
5 months ago
5 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
5 months ago
6 months ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. from typing import Optional
  2. from extensions.ext_database import db
  3. from models.dataset import Dataset, Pipeline, PipelineBuiltInTemplate
  4. from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase
  5. from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType
  6. class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
  7. """
  8. Retrieval pipeline template from database
  9. """
  10. def get_pipeline_templates(self, language: str) -> dict:
  11. result = self.fetch_pipeline_templates_from_db(language)
  12. return result
  13. def get_pipeline_template_detail(self, pipeline_id: str):
  14. result = self.fetch_pipeline_template_detail_from_db(pipeline_id)
  15. return result
  16. def get_type(self) -> str:
  17. return PipelineTemplateType.DATABASE
  18. @classmethod
  19. def fetch_pipeline_templates_from_db(cls, language: str) -> dict:
  20. """
  21. Fetch pipeline templates from db.
  22. :param language: language
  23. :return:
  24. """
  25. pipeline_built_in_templates: list[PipelineBuiltInTemplate] = (
  26. db.session.query(PipelineBuiltInTemplate).filter(PipelineBuiltInTemplate.language == language).all()
  27. )
  28. recommended_pipelines_results = []
  29. for pipeline_built_in_template in pipeline_built_in_templates:
  30. pipeline_model: Pipeline = pipeline_built_in_template.pipeline
  31. recommended_pipeline_result = {
  32. "id": pipeline_built_in_template.id,
  33. "name": pipeline_built_in_template.name,
  34. "pipeline_id": pipeline_model.id,
  35. "description": pipeline_built_in_template.description,
  36. "icon": pipeline_built_in_template.icon,
  37. "copyright": pipeline_built_in_template.copyright,
  38. "privacy_policy": pipeline_built_in_template.privacy_policy,
  39. "position": pipeline_built_in_template.position,
  40. }
  41. dataset: Dataset = pipeline_model.dataset
  42. if dataset:
  43. recommended_pipeline_result["chunk_structure"] = dataset.chunk_structure
  44. recommended_pipelines_results.append(recommended_pipeline_result)
  45. return {"pipeline_templates": recommended_pipelines_results}
  46. @classmethod
  47. def fetch_pipeline_template_detail_from_db(cls, pipeline_id: str) -> Optional[dict]:
  48. """
  49. Fetch pipeline template detail from db.
  50. :param pipeline_id: Pipeline ID
  51. :return:
  52. """
  53. from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
  54. # is in public recommended list
  55. pipeline_template = (
  56. db.session.query(PipelineBuiltInTemplate).filter(PipelineBuiltInTemplate.id == pipeline_id).first()
  57. )
  58. if not pipeline_template:
  59. return None
  60. # get app detail
  61. pipeline = db.session.query(Pipeline).filter(Pipeline.id == pipeline_template.pipeline_id).first()
  62. if not pipeline or not pipeline.is_public:
  63. return None
  64. return {
  65. "id": pipeline.id,
  66. "name": pipeline.name,
  67. "icon": pipeline.icon,
  68. "mode": pipeline.mode,
  69. "export_data": RagPipelineDslService.export_rag_pipeline_dsl(pipeline=pipeline),
  70. }