Du kan inte välja fler än 25 ämnen. Ämnen måste starta med en bokstav eller siffra, kan innehålla bindestreck ('-') och vara max 35 tecken långa.

output_moderation.py 4.7KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. import logging
  2. import threading
  3. import time
  4. from typing import Any, Optional
  5. from flask import Flask, current_app
  6. from pydantic import BaseModel, ConfigDict
  7. from configs import dify_config
  8. from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
  9. from core.app.entities.queue_entities import QueueMessageReplaceEvent
  10. from core.moderation.base import ModerationAction, ModerationOutputsResult
  11. from core.moderation.factory import ModerationFactory
  12. logger = logging.getLogger(__name__)
  13. class ModerationRule(BaseModel):
  14. type: str
  15. config: dict[str, Any]
  16. class OutputModeration(BaseModel):
  17. tenant_id: str
  18. app_id: str
  19. rule: ModerationRule
  20. queue_manager: AppQueueManager
  21. thread: Optional[threading.Thread] = None
  22. thread_running: bool = True
  23. buffer: str = ""
  24. is_final_chunk: bool = False
  25. final_output: Optional[str] = None
  26. model_config = ConfigDict(arbitrary_types_allowed=True)
  27. def should_direct_output(self) -> bool:
  28. return self.final_output is not None
  29. def get_final_output(self) -> str:
  30. return self.final_output or ""
  31. def append_new_token(self, token: str) -> None:
  32. self.buffer += token
  33. if not self.thread:
  34. self.thread = self.start_thread()
  35. def moderation_completion(self, completion: str, public_event: bool = False) -> tuple[str, bool]:
  36. self.buffer = completion
  37. self.is_final_chunk = True
  38. result = self.moderation(tenant_id=self.tenant_id, app_id=self.app_id, moderation_buffer=completion)
  39. if not result or not result.flagged:
  40. return completion, False
  41. if result.action == ModerationAction.DIRECT_OUTPUT:
  42. final_output = result.preset_response
  43. else:
  44. final_output = result.text
  45. if public_event:
  46. self.queue_manager.publish(
  47. QueueMessageReplaceEvent(
  48. text=final_output, reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION
  49. ),
  50. PublishFrom.TASK_PIPELINE,
  51. )
  52. return final_output, True
  53. def start_thread(self) -> threading.Thread:
  54. buffer_size = dify_config.MODERATION_BUFFER_SIZE
  55. thread = threading.Thread(
  56. target=self.worker,
  57. kwargs={
  58. "flask_app": current_app._get_current_object(), # type: ignore
  59. "buffer_size": buffer_size if buffer_size > 0 else dify_config.MODERATION_BUFFER_SIZE,
  60. },
  61. )
  62. thread.start()
  63. return thread
  64. def stop_thread(self):
  65. if self.thread and self.thread.is_alive():
  66. self.thread_running = False
  67. def worker(self, flask_app: Flask, buffer_size: int):
  68. with flask_app.app_context():
  69. current_length = 0
  70. while self.thread_running:
  71. moderation_buffer = self.buffer
  72. buffer_length = len(moderation_buffer)
  73. if not self.is_final_chunk:
  74. chunk_length = buffer_length - current_length
  75. if 0 <= chunk_length < buffer_size:
  76. time.sleep(1)
  77. continue
  78. current_length = buffer_length
  79. result = self.moderation(
  80. tenant_id=self.tenant_id, app_id=self.app_id, moderation_buffer=moderation_buffer
  81. )
  82. if not result or not result.flagged:
  83. continue
  84. if result.action == ModerationAction.DIRECT_OUTPUT:
  85. final_output = result.preset_response
  86. self.final_output = final_output
  87. else:
  88. final_output = result.text + self.buffer[len(moderation_buffer) :]
  89. # trigger replace event
  90. if self.thread_running:
  91. self.queue_manager.publish(
  92. QueueMessageReplaceEvent(
  93. text=final_output, reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION
  94. ),
  95. PublishFrom.TASK_PIPELINE,
  96. )
  97. if result.action == ModerationAction.DIRECT_OUTPUT:
  98. break
  99. def moderation(self, tenant_id: str, app_id: str, moderation_buffer: str) -> Optional[ModerationOutputsResult]:
  100. try:
  101. moderation_factory = ModerationFactory(
  102. name=self.rule.type, app_id=app_id, tenant_id=tenant_id, config=self.rule.config
  103. )
  104. result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer)
  105. return result
  106. except Exception:
  107. logger.exception("Moderation Output error, app_id: %s", app_id)
  108. return None