Vous ne pouvez pas sélectionner plus de 25 sujets Les noms de sujets doivent commencer par une lettre ou un nombre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.

tool_label_manager.py 3.5KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. from sqlalchemy import select
  2. from core.tools.__base.tool_provider import ToolProviderController
  3. from core.tools.builtin_tool.provider import BuiltinToolProviderController
  4. from core.tools.custom_tool.provider import ApiToolProviderController
  5. from core.tools.entities.values import default_tool_label_name_list
  6. from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
  7. from extensions.ext_database import db
  8. from models.tools import ToolLabelBinding
  9. class ToolLabelManager:
  10. @classmethod
  11. def filter_tool_labels(cls, tool_labels: list[str]) -> list[str]:
  12. """
  13. Filter tool labels
  14. """
  15. tool_labels = [label for label in tool_labels if label in default_tool_label_name_list]
  16. return list(set(tool_labels))
  17. @classmethod
  18. def update_tool_labels(cls, controller: ToolProviderController, labels: list[str]):
  19. """
  20. Update tool labels
  21. """
  22. labels = cls.filter_tool_labels(labels)
  23. if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
  24. provider_id = controller.provider_id # ty: ignore [unresolved-attribute]
  25. else:
  26. raise ValueError("Unsupported tool type")
  27. # delete old labels
  28. db.session.query(ToolLabelBinding).where(ToolLabelBinding.tool_id == provider_id).delete()
  29. # insert new labels
  30. for label in labels:
  31. db.session.add(
  32. ToolLabelBinding(
  33. tool_id=provider_id,
  34. tool_type=controller.provider_type.value,
  35. label_name=label,
  36. )
  37. )
  38. db.session.commit()
  39. @classmethod
  40. def get_tool_labels(cls, controller: ToolProviderController) -> list[str]:
  41. """
  42. Get tool labels
  43. """
  44. if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
  45. provider_id = controller.provider_id # ty: ignore [unresolved-attribute]
  46. elif isinstance(controller, BuiltinToolProviderController):
  47. return controller.tool_labels
  48. else:
  49. raise ValueError("Unsupported tool type")
  50. stmt = select(ToolLabelBinding.label_name).where(
  51. ToolLabelBinding.tool_id == provider_id,
  52. ToolLabelBinding.tool_type == controller.provider_type.value,
  53. )
  54. labels = db.session.scalars(stmt).all()
  55. return list(labels)
  56. @classmethod
  57. def get_tools_labels(cls, tool_providers: list[ToolProviderController]) -> dict[str, list[str]]:
  58. """
  59. Get tools labels
  60. :param tool_providers: list of tool providers
  61. :return: dict of tool labels
  62. :key: tool id
  63. :value: list of tool labels
  64. """
  65. if not tool_providers:
  66. return {}
  67. for controller in tool_providers:
  68. if not isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
  69. raise ValueError("Unsupported tool type")
  70. provider_ids = []
  71. for controller in tool_providers:
  72. assert isinstance(controller, ApiToolProviderController | WorkflowToolProviderController)
  73. provider_ids.append(controller.provider_id) # ty: ignore [unresolved-attribute]
  74. labels = db.session.scalars(select(ToolLabelBinding).where(ToolLabelBinding.tool_id.in_(provider_ids))).all()
  75. tool_labels: dict[str, list[str]] = {label.tool_id: [] for label in labels}
  76. for label in labels:
  77. tool_labels[label.tool_id].append(label.label_name)
  78. return tool_labels