| @@ -61,6 +61,8 @@ DEFAULTS = { | |||
| 'HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA': 1000000, | |||
| 'HOSTED_ANTHROPIC_PAID_MIN_QUANTITY': 20, | |||
| 'HOSTED_ANTHROPIC_PAID_MAX_QUANTITY': 100, | |||
| 'HOSTED_MODERATION_ENABLED': 'False', | |||
| 'HOSTED_MODERATION_PROVIDERS': '', | |||
| 'TENANT_DOCUMENT_COUNT': 100, | |||
| 'CLEAN_DAY_SETTING': 30, | |||
| 'UPLOAD_FILE_SIZE_LIMIT': 15, | |||
| @@ -230,6 +232,9 @@ class Config: | |||
| self.HOSTED_ANTHROPIC_PAID_MIN_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MIN_QUANTITY')) | |||
| self.HOSTED_ANTHROPIC_PAID_MAX_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MAX_QUANTITY')) | |||
| self.HOSTED_MODERATION_ENABLED = get_bool_env('HOSTED_MODERATION_ENABLED') | |||
| self.HOSTED_MODERATION_PROVIDERS = get_env('HOSTED_MODERATION_PROVIDERS') | |||
| self.STRIPE_API_KEY = get_env('STRIPE_API_KEY') | |||
| self.STRIPE_WEBHOOK_SECRET = get_env('STRIPE_WEBHOOK_SECRET') | |||
| @@ -16,6 +16,8 @@ from core.agent.agent.structed_multi_dataset_router_agent import StructuredMulti | |||
| from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent | |||
| from langchain.agents import AgentExecutor as LCAgentExecutor | |||
| from core.helper import moderation | |||
| from core.model_providers.error import LLMError | |||
| from core.model_providers.models.llm.base import BaseLLM | |||
| from core.tool.dataset_retriever_tool import DatasetRetrieverTool | |||
| @@ -116,6 +118,18 @@ class AgentExecutor: | |||
| return self.agent.should_use_agent(query) | |||
| def run(self, query: str) -> AgentExecuteResult: | |||
| moderation_result = moderation.check_moderation( | |||
| self.configuration.model_instance.model_provider, | |||
| query | |||
| ) | |||
| if not moderation_result: | |||
| return AgentExecuteResult( | |||
| output="I apologize for any confusion, but I'm an AI assistant to be helpful, harmless, and honest.", | |||
| strategy=self.configuration.strategy, | |||
| configuration=self.configuration | |||
| ) | |||
| agent_executor = LCAgentExecutor.from_agent_and_tools( | |||
| agent=self.agent, | |||
| tools=self.configuration.tools, | |||
| @@ -128,7 +142,9 @@ class AgentExecutor: | |||
| try: | |||
| output = agent_executor.run(query) | |||
| except Exception: | |||
| except LLMError as ex: | |||
| raise ex | |||
| except Exception as ex: | |||
| logging.exception("agent_executor run failed") | |||
| output = None | |||
| @@ -6,7 +6,7 @@ from typing import Any, Dict, List, Union, Optional | |||
| from langchain.agents import openai_functions_agent, openai_functions_multi_agent | |||
| from langchain.callbacks.base import BaseCallbackHandler | |||
| from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration | |||
| from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration, BaseMessage | |||
| from core.callback_handler.entity.agent_loop import AgentLoop | |||
| from core.conversation_message_task import ConversationMessageTask | |||
| @@ -18,9 +18,9 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler): | |||
| """Callback Handler that prints to std out.""" | |||
| raise_error: bool = True | |||
| def __init__(self, model_instant: BaseLLM, conversation_message_task: ConversationMessageTask) -> None: | |||
| def __init__(self, model_instance: BaseLLM, conversation_message_task: ConversationMessageTask) -> None: | |||
| """Initialize callback handler.""" | |||
| self.model_instant = model_instant | |||
| self.model_instance = model_instance | |||
| self.conversation_message_task = conversation_message_task | |||
| self._agent_loops = [] | |||
| self._current_loop = None | |||
| @@ -46,6 +46,21 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler): | |||
| """Whether to ignore chain callbacks.""" | |||
| return True | |||
| def on_chat_model_start( | |||
| self, | |||
| serialized: Dict[str, Any], | |||
| messages: List[List[BaseMessage]], | |||
| **kwargs: Any | |||
| ) -> Any: | |||
| if not self._current_loop: | |||
| # The agent loop starts with an LLM query | |||
| self._current_loop = AgentLoop( | |||
| position=len(self._agent_loops) + 1, | |||
| prompt="\n".join([message.content for message in messages[0]]), | |||
| status='llm_started', | |||
| started_at=time.perf_counter() | |||
| ) | |||
| def on_llm_start( | |||
| self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any | |||
| ) -> None: | |||
| @@ -70,7 +85,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler): | |||
| if response.llm_output: | |||
| self._current_loop.prompt_tokens = response.llm_output['token_usage']['prompt_tokens'] | |||
| else: | |||
| self._current_loop.prompt_tokens = self.model_instant.get_num_tokens( | |||
| self._current_loop.prompt_tokens = self.model_instance.get_num_tokens( | |||
| [PromptMessage(content=self._current_loop.prompt)] | |||
| ) | |||
| completion_generation = response.generations[0][0] | |||
| @@ -87,7 +102,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler): | |||
| if response.llm_output: | |||
| self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens'] | |||
| else: | |||
| self._current_loop.completion_tokens = self.model_instant.get_num_tokens( | |||
| self._current_loop.completion_tokens = self.model_instance.get_num_tokens( | |||
| [PromptMessage(content=self._current_loop.completion)] | |||
| ) | |||
| @@ -162,7 +177,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler): | |||
| self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at | |||
| self.conversation_message_task.on_agent_end( | |||
| self._message_agent_thought, self.model_instant, self._current_loop | |||
| self._message_agent_thought, self.model_instance, self._current_loop | |||
| ) | |||
| self._agent_loops.append(self._current_loop) | |||
| @@ -193,7 +208,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler): | |||
| ) | |||
| self.conversation_message_task.on_agent_end( | |||
| self._message_agent_thought, self.model_instant, self._current_loop | |||
| self._message_agent_thought, self.model_instance, self._current_loop | |||
| ) | |||
| self._agent_loops.append(self._current_loop) | |||
| @@ -6,4 +6,3 @@ class LLMMessage(BaseModel): | |||
| prompt_tokens: int = 0 | |||
| completion: str = '' | |||
| completion_tokens: int = 0 | |||
| latency: float = 0.0 | |||
| @@ -1,5 +1,4 @@ | |||
| import logging | |||
| import time | |||
| from typing import Any, Dict, List, Union | |||
| from langchain.callbacks.base import BaseCallbackHandler | |||
| @@ -32,7 +31,6 @@ class LLMCallbackHandler(BaseCallbackHandler): | |||
| messages: List[List[BaseMessage]], | |||
| **kwargs: Any | |||
| ) -> Any: | |||
| self.start_at = time.perf_counter() | |||
| real_prompts = [] | |||
| for message in messages[0]: | |||
| if message.type == 'human': | |||
| @@ -53,8 +51,6 @@ class LLMCallbackHandler(BaseCallbackHandler): | |||
| def on_llm_start( | |||
| self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any | |||
| ) -> None: | |||
| self.start_at = time.perf_counter() | |||
| self.llm_message.prompt = [{ | |||
| "role": 'user', | |||
| "text": prompts[0] | |||
| @@ -63,9 +59,6 @@ class LLMCallbackHandler(BaseCallbackHandler): | |||
| self.llm_message.prompt_tokens = self.model_instance.get_num_tokens([PromptMessage(content=prompts[0])]) | |||
| def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: | |||
| end_at = time.perf_counter() | |||
| self.llm_message.latency = end_at - self.start_at | |||
| if not self.conversation_message_task.streaming: | |||
| self.conversation_message_task.append_message_text(response.generations[0][0].text) | |||
| self.llm_message.completion = response.generations[0][0].text | |||
| @@ -89,8 +82,6 @@ class LLMCallbackHandler(BaseCallbackHandler): | |||
| """Do nothing.""" | |||
| if isinstance(error, ConversationTaskStoppedException): | |||
| if self.conversation_message_task.streaming: | |||
| end_at = time.perf_counter() | |||
| self.llm_message.latency = end_at - self.start_at | |||
| self.llm_message.completion_tokens = self.model_instance.get_num_tokens( | |||
| [PromptMessage(content=self.llm_message.completion)] | |||
| ) | |||
| @@ -1,15 +1,38 @@ | |||
| import enum | |||
| import logging | |||
| from typing import List, Dict, Optional, Any | |||
| import openai | |||
| from flask import current_app | |||
| from langchain.callbacks.manager import CallbackManagerForChainRun | |||
| from langchain.chains.base import Chain | |||
| from openai import InvalidRequestError | |||
| from openai.error import APIConnectionError, APIError, ServiceUnavailableError, Timeout, RateLimitError, \ | |||
| AuthenticationError, OpenAIError | |||
| from pydantic import BaseModel | |||
| from core.model_providers.error import LLMBadRequestError | |||
| from core.model_providers.model_factory import ModelFactory | |||
| from core.model_providers.models.llm.base import BaseLLM | |||
| from core.model_providers.models.moderation import openai_moderation | |||
| class SensitiveWordAvoidanceRule(BaseModel): | |||
| class Type(enum.Enum): | |||
| MODERATION = "moderation" | |||
| KEYWORDS = "keywords" | |||
| type: Type | |||
| canned_response: str = 'Your content violates our usage policy. Please revise and try again.' | |||
| extra_params: dict = {} | |||
| class SensitiveWordAvoidanceChain(Chain): | |||
| input_key: str = "input" #: :meta private: | |||
| output_key: str = "output" #: :meta private: | |||
| sensitive_words: List[str] = [] | |||
| canned_response: str = None | |||
| model_instance: BaseLLM | |||
| sensitive_word_avoidance_rule: SensitiveWordAvoidanceRule | |||
| @property | |||
| def _chain_type(self) -> str: | |||
| @@ -31,11 +54,24 @@ class SensitiveWordAvoidanceChain(Chain): | |||
| """ | |||
| return [self.output_key] | |||
| def _check_sensitive_word(self, text: str) -> str: | |||
| for word in self.sensitive_words: | |||
| def _check_sensitive_word(self, text: str) -> bool: | |||
| for word in self.sensitive_word_avoidance_rule.extra_params.get('sensitive_words', []): | |||
| if word in text: | |||
| return self.canned_response | |||
| return text | |||
| return False | |||
| return True | |||
| def _check_moderation(self, text: str) -> bool: | |||
| moderation_model_instance = ModelFactory.get_moderation_model( | |||
| tenant_id=self.model_instance.model_provider.provider.tenant_id, | |||
| model_provider_name='openai', | |||
| model_name=openai_moderation.DEFAULT_MODEL | |||
| ) | |||
| try: | |||
| return moderation_model_instance.run(text=text) | |||
| except Exception as ex: | |||
| logging.exception(ex) | |||
| raise LLMBadRequestError('Rate limit exceeded, please try again later.') | |||
| def _call( | |||
| self, | |||
| @@ -43,5 +79,13 @@ class SensitiveWordAvoidanceChain(Chain): | |||
| run_manager: Optional[CallbackManagerForChainRun] = None, | |||
| ) -> Dict[str, Any]: | |||
| text = inputs[self.input_key] | |||
| output = self._check_sensitive_word(text) | |||
| return {self.output_key: output} | |||
| if self.sensitive_word_avoidance_rule.type == SensitiveWordAvoidanceRule.Type.KEYWORDS: | |||
| result = self._check_sensitive_word(text) | |||
| else: | |||
| result = self._check_moderation(text) | |||
| if not result: | |||
| raise LLMBadRequestError(self.sensitive_word_avoidance_rule.canned_response) | |||
| return {self.output_key: text} | |||
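For reference, a minimal sketch of driving the keyword path of this chain; `model_instance` is assumed to be an existing `BaseLLM`, and the word list is illustrative:

```python
from core.chain.sensitive_word_avoidance_chain import (
    SensitiveWordAvoidanceChain,
    SensitiveWordAvoidanceRule,
)
from core.model_providers.error import LLMBadRequestError

rule = SensitiveWordAvoidanceRule(
    type=SensitiveWordAvoidanceRule.Type.KEYWORDS,
    extra_params={'sensitive_words': ['blocked-word']},  # illustrative keywords
)
chain = SensitiveWordAvoidanceChain(
    model_instance=model_instance,  # assumed: a BaseLLM obtained elsewhere
    sensitive_word_avoidance_rule=rule,
    output_key="sensitive_word_avoidance_output",
)

try:
    # Returns the text untouched when it passes the check.
    passed_text = chain.run("a user query")
except LLMBadRequestError as ex:
    # Raised with the rule's canned_response when the text is flagged.
    print(ex)
```

Note the behavioral change here: a violation now raises `LLMBadRequestError` carrying `canned_response`, instead of returning the canned response as the chain output.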
| @@ -1,9 +1,7 @@ | |||
| import json | |||
| import logging | |||
| import re | |||
| from typing import Optional, List, Union, Tuple | |||
| from typing import Optional, List, Union | |||
| from langchain.schema import BaseMessage | |||
| from requests.exceptions import ChunkedEncodingError | |||
| from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy | |||
| @@ -14,11 +12,10 @@ from core.model_providers.error import LLMBadRequestError | |||
| from core.memory.read_only_conversation_token_db_buffer_shared_memory import \ | |||
| ReadOnlyConversationTokenDBBufferSharedMemory | |||
| from core.model_providers.model_factory import ModelFactory | |||
| from core.model_providers.models.entity.message import PromptMessage, to_prompt_messages | |||
| from core.model_providers.models.entity.message import PromptMessage | |||
| from core.model_providers.models.llm.base import BaseLLM | |||
| from core.orchestrator_rule_parser import OrchestratorRuleParser | |||
| from core.prompt.prompt_builder import PromptBuilder | |||
| from core.prompt.prompt_template import JinjaPromptTemplate | |||
| from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT | |||
| from models.dataset import DocumentSegment, Dataset, Document | |||
| from models.model import App, AppModelConfig, Account, Conversation, Message, EndUser | |||
| @@ -81,7 +78,7 @@ class Completion: | |||
| # parse sensitive_word_avoidance_chain | |||
| chain_callback = MainChainGatherCallbackHandler(conversation_message_task) | |||
| sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain([chain_callback]) | |||
| sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain(final_model_instance, [chain_callback]) | |||
| if sensitive_word_avoidance_chain: | |||
| query = sensitive_word_avoidance_chain.run(query) | |||
| @@ -1,5 +1,5 @@ | |||
| import decimal | |||
| import json | |||
| import time | |||
| from typing import Optional, Union, List | |||
| from core.callback_handler.entity.agent_loop import AgentLoop | |||
| @@ -23,6 +23,8 @@ class ConversationMessageTask: | |||
| def __init__(self, task_id: str, app: App, app_model_config: AppModelConfig, user: Account, | |||
| inputs: dict, query: str, streaming: bool, model_instance: BaseLLM, | |||
| conversation: Optional[Conversation] = None, is_override: bool = False): | |||
| self.start_at = time.perf_counter() | |||
| self.task_id = task_id | |||
| self.app = app | |||
| @@ -61,6 +63,7 @@ class ConversationMessageTask: | |||
| ) | |||
| def init(self): | |||
| override_model_configs = None | |||
| if self.is_override: | |||
| override_model_configs = self.app_model_config.to_dict() | |||
| @@ -165,7 +168,7 @@ class ConversationMessageTask: | |||
| self.message.answer_tokens = answer_tokens | |||
| self.message.answer_unit_price = answer_unit_price | |||
| self.message.answer_price_unit = answer_price_unit | |||
| self.message.provider_response_latency = llm_message.latency | |||
| self.message.provider_response_latency = time.perf_counter() - self.start_at | |||
| self.message.total_price = total_price | |||
| db.session.commit() | |||
| @@ -220,18 +223,18 @@ class ConversationMessageTask: | |||
| return message_agent_thought | |||
| def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instant: BaseLLM, | |||
| def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instance: BaseLLM, | |||
| agent_loop: AgentLoop): | |||
| agent_message_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.HUMAN) | |||
| agent_message_price_unit = agent_model_instant.get_price_unit(MessageType.HUMAN) | |||
| agent_answer_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.ASSISTANT) | |||
| agent_answer_price_unit = agent_model_instant.get_price_unit(MessageType.ASSISTANT) | |||
| agent_message_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.HUMAN) | |||
| agent_message_price_unit = agent_model_instance.get_price_unit(MessageType.HUMAN) | |||
| agent_answer_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.ASSISTANT) | |||
| agent_answer_price_unit = agent_model_instance.get_price_unit(MessageType.ASSISTANT) | |||
| loop_message_tokens = agent_loop.prompt_tokens | |||
| loop_answer_tokens = agent_loop.completion_tokens | |||
| loop_message_total_price = agent_model_instant.calc_tokens_price(loop_message_tokens, MessageType.HUMAN) | |||
| loop_answer_total_price = agent_model_instant.calc_tokens_price(loop_answer_tokens, MessageType.ASSISTANT) | |||
| loop_message_total_price = agent_model_instance.calc_tokens_price(loop_message_tokens, MessageType.HUMAN) | |||
| loop_answer_total_price = agent_model_instance.calc_tokens_price(loop_answer_tokens, MessageType.ASSISTANT) | |||
| loop_total_price = loop_message_total_price + loop_answer_total_price | |||
| message_agent_thought.observation = agent_loop.tool_output | |||
| @@ -245,7 +248,7 @@ class ConversationMessageTask: | |||
| message_agent_thought.latency = agent_loop.latency | |||
| message_agent_thought.tokens = agent_loop.prompt_tokens + agent_loop.completion_tokens | |||
| message_agent_thought.total_price = loop_total_price | |||
| message_agent_thought.currency = agent_model_instant.get_currency() | |||
| message_agent_thought.currency = agent_model_instance.get_currency() | |||
| db.session.flush() | |||
| def on_dataset_query_end(self, dataset_query_obj: DatasetQueryObj): | |||
| @@ -0,0 +1,32 @@ | |||
| import logging | |||
| import openai | |||
| from flask import current_app | |||
| from core.model_providers.error import LLMBadRequestError | |||
| from core.model_providers.providers.base import BaseModelProvider | |||
| from models.provider import ProviderType | |||
| def check_moderation(model_provider: BaseModelProvider, text: str) -> bool: | |||
| if current_app.config['HOSTED_MODERATION_ENABLED'] and current_app.config['HOSTED_MODERATION_PROVIDERS']: | |||
| moderation_providers = current_app.config['HOSTED_MODERATION_PROVIDERS'].split(',') | |||
| if model_provider.provider.provider_type == ProviderType.SYSTEM.value \ | |||
| and model_provider.provider_name in moderation_providers: | |||
| # split the text into 2,000-character chunks | |||
| length = 2000 | |||
| chunks = [text[i:i + length] for i in range(0, len(text), length)] | |||
| try: | |||
| moderation_result = openai.Moderation.create(input=chunks, | |||
| api_key=current_app.config['HOSTED_OPENAI_API_KEY']) | |||
| except Exception as ex: | |||
| logging.exception(ex) | |||
| raise LLMBadRequestError('Rate limit exceeded, please try again later.') | |||
| for result in moderation_result.results: | |||
| if result['flagged'] is True: | |||
| return False | |||
| return True | |||
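A hedged sketch of a call site for this helper (the wrapper function and the fallback constant are placeholders, not part of this PR); it mirrors the usage in `AgentExecutor.run` and `BaseLLM.run`:

```python
from typing import Optional

from core.helper import moderation
from core.model_providers.models.llm.base import BaseLLM

# Placeholder fallback text, mirroring the canned answer used above.
FALLBACK_ANSWER = ("I apologize for any confusion, but I'm an AI assistant "
                   "designed to be helpful, harmless, and honest.")


def precheck_query(model_instance: BaseLLM, query: str) -> Optional[str]:
    """Hypothetical helper: return a fallback answer if hosted moderation flags the query."""
    if not moderation.check_moderation(model_instance.model_provider, query):
        return FALLBACK_ANSWER
    return None  # caller continues with the normal completion path
```

The check only applies to system (hosted) providers listed in `HOSTED_MODERATION_PROVIDERS`; for every other provider it returns `True` and is effectively a no-op.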
| @@ -8,6 +8,7 @@ from core.model_providers.models.base import BaseProviderModel | |||
| from core.model_providers.models.embedding.base import BaseEmbedding | |||
| from core.model_providers.models.entity.model_params import ModelKwargs, ModelType | |||
| from core.model_providers.models.llm.base import BaseLLM | |||
| from core.model_providers.models.moderation.base import BaseModeration | |||
| from core.model_providers.models.speech2text.base import BaseSpeech2Text | |||
| from extensions.ext_database import db | |||
| from models.provider import TenantDefaultModel | |||
| @@ -180,7 +181,7 @@ class ModelFactory: | |||
| def get_moderation_model(cls, | |||
| tenant_id: str, | |||
| model_provider_name: str, | |||
| model_name: str) -> Optional[BaseProviderModel]: | |||
| model_name: str) -> Optional[BaseModeration]: | |||
| """ | |||
| get moderation model. | |||
| @@ -10,6 +10,7 @@ from langchain.memory.chat_memory import BaseChatMemory | |||
| from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage, BaseMessage, ChatGeneration | |||
| from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, DifyStdOutCallbackHandler | |||
| from core.helper import moderation | |||
| from core.model_providers.models.base import BaseProviderModel | |||
| from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_prompt_messages | |||
| from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules | |||
| @@ -116,6 +117,15 @@ class BaseLLM(BaseProviderModel): | |||
| :param callbacks: | |||
| :return: | |||
| """ | |||
| moderation_result = moderation.check_moderation( | |||
| self.model_provider, | |||
| "\n".join([message.content for message in messages]) | |||
| ) | |||
| if not moderation_result: | |||
| kwargs['fake_response'] = "I apologize for any confusion, " \ | |||
| "but I'm an AI assistant to be helpful, harmless, and honest." | |||
| if self.deduct_quota: | |||
| self.model_provider.check_quota_over_limit() | |||
| @@ -0,0 +1,29 @@ | |||
| from abc import abstractmethod | |||
| from typing import Any | |||
| from core.model_providers.models.base import BaseProviderModel | |||
| from core.model_providers.models.entity.model_params import ModelType | |||
| from core.model_providers.providers.base import BaseModelProvider | |||
| class BaseModeration(BaseProviderModel): | |||
| name: str | |||
| type: ModelType = ModelType.MODERATION | |||
| def __init__(self, model_provider: BaseModelProvider, client: Any, name: str): | |||
| super().__init__(model_provider, client) | |||
| self.name = name | |||
| def run(self, text: str) -> bool: | |||
| try: | |||
| return self._run(text) | |||
| except Exception as ex: | |||
| raise self.handle_exceptions(ex) | |||
| @abstractmethod | |||
| def _run(self, text: str) -> bool: | |||
| raise NotImplementedError | |||
| @abstractmethod | |||
| def handle_exceptions(self, ex: Exception) -> Exception: | |||
| raise NotImplementedError | |||
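To make the contract concrete, a hypothetical subclass (not part of this PR); the real OpenAI implementation follows in the next hunk:

```python
from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.moderation.base import BaseModeration


class KeywordModeration(BaseModeration):
    """Hypothetical moderation model that flags text containing banned keywords."""

    BANNED_WORDS = ('banned-word',)  # illustrative only

    def _run(self, text: str) -> bool:
        # True means the text passed moderation; False means it was flagged.
        return not any(word in text for word in self.BANNED_WORDS)

    def handle_exceptions(self, ex: Exception) -> Exception:
        # Map any provider error onto the shared error hierarchy.
        return LLMBadRequestError(str(ex))
```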
| @@ -4,29 +4,35 @@ import openai | |||
| from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \ | |||
| LLMRateLimitError, LLMAuthorizationError | |||
| from core.model_providers.models.base import BaseProviderModel | |||
| from core.model_providers.models.entity.model_params import ModelType | |||
| from core.model_providers.models.moderation.base import BaseModeration | |||
| from core.model_providers.providers.base import BaseModelProvider | |||
| DEFAULT_AUDIO_MODEL = 'whisper-1' | |||
| DEFAULT_MODEL = 'whisper-1' | |||
| class OpenAIModeration(BaseProviderModel): | |||
| type: ModelType = ModelType.MODERATION | |||
| class OpenAIModeration(BaseModeration): | |||
| def __init__(self, model_provider: BaseModelProvider, name: str): | |||
| super().__init__(model_provider, openai.Moderation) | |||
| super().__init__(model_provider, openai.Moderation, name) | |||
| def run(self, text): | |||
| def _run(self, text: str) -> bool: | |||
| credentials = self.model_provider.get_model_credentials( | |||
| model_name=DEFAULT_AUDIO_MODEL, | |||
| model_name=self.name, | |||
| model_type=self.type | |||
| ) | |||
| try: | |||
| return self._client.create(input=text, api_key=credentials['openai_api_key']) | |||
| except Exception as ex: | |||
| raise self.handle_exceptions(ex) | |||
| # split the text into 2,000-character chunks | |||
| length = 2000 | |||
| chunks = [text[i:i + length] for i in range(0, len(text), length)] | |||
| moderation_result = self._client.create(input=chunks, | |||
| api_key=credentials['openai_api_key']) | |||
| for result in moderation_result.results: | |||
| if result['flagged'] is True: | |||
| return False | |||
| return True | |||
| def handle_exceptions(self, ex: Exception) -> Exception: | |||
| if isinstance(ex, openai.error.InvalidRequestError): | |||
| @@ -1,6 +1,7 @@ | |||
| import math | |||
| from typing import Optional | |||
| from flask import current_app | |||
| from langchain import WikipediaAPIWrapper | |||
| from langchain.callbacks.manager import Callbacks | |||
| from langchain.memory.chat_memory import BaseChatMemory | |||
| @@ -12,7 +13,7 @@ from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGa | |||
| from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler | |||
| from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler | |||
| from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler | |||
| from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain | |||
| from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain, SensitiveWordAvoidanceRule | |||
| from core.conversation_message_task import ConversationMessageTask | |||
| from core.model_providers.error import ProviderTokenNotInitError | |||
| from core.model_providers.model_factory import ModelFactory | |||
| @@ -26,6 +27,7 @@ from core.tool.web_reader_tool import WebReaderTool | |||
| from extensions.ext_database import db | |||
| from models.dataset import Dataset, DatasetProcessRule | |||
| from models.model import AppModelConfig | |||
| from models.provider import ProviderType | |||
| class OrchestratorRuleParser: | |||
| @@ -63,7 +65,7 @@ class OrchestratorRuleParser: | |||
| # add agent callback to record agent thoughts | |||
| agent_callback = AgentLoopGatherCallbackHandler( | |||
| model_instant=agent_model_instance, | |||
| model_instance=agent_model_instance, | |||
| conversation_message_task=conversation_message_task | |||
| ) | |||
| @@ -123,23 +125,45 @@ class OrchestratorRuleParser: | |||
| return chain | |||
| def to_sensitive_word_avoidance_chain(self, callbacks: Callbacks = None, **kwargs) \ | |||
| def to_sensitive_word_avoidance_chain(self, model_instance: BaseLLM, callbacks: Callbacks = None, **kwargs) \ | |||
| -> Optional[SensitiveWordAvoidanceChain]: | |||
| """ | |||
| Convert app sensitive word avoidance config to chain | |||
| :param model_instance: model instance | |||
| :param callbacks: callbacks for the chain | |||
| :param kwargs: | |||
| :return: | |||
| """ | |||
| if not self.app_model_config.sensitive_word_avoidance_dict: | |||
| return None | |||
| sensitive_word_avoidance_config = self.app_model_config.sensitive_word_avoidance_dict | |||
| sensitive_words = sensitive_word_avoidance_config.get("words", "") | |||
| if sensitive_word_avoidance_config.get("enabled", False) and sensitive_words: | |||
| sensitive_word_avoidance_rule = None | |||
| if self.app_model_config.sensitive_word_avoidance_dict: | |||
| sensitive_word_avoidance_config = self.app_model_config.sensitive_word_avoidance_dict | |||
| if sensitive_word_avoidance_config.get("enabled", False): | |||
| if sensitive_word_avoidance_config.get('type') == 'moderation': | |||
| sensitive_word_avoidance_rule = SensitiveWordAvoidanceRule( | |||
| type=SensitiveWordAvoidanceRule.Type.MODERATION, | |||
| canned_response=sensitive_word_avoidance_config.get("canned_response") | |||
| if sensitive_word_avoidance_config.get("canned_response") | |||
| else 'Your content violates our usage policy. Please revise and try again.', | |||
| ) | |||
| else: | |||
| sensitive_words = sensitive_word_avoidance_config.get("words", "") | |||
| if sensitive_words: | |||
| sensitive_word_avoidance_rule = SensitiveWordAvoidanceRule( | |||
| type=SensitiveWordAvoidanceRule.Type.KEYWORDS, | |||
| canned_response=sensitive_word_avoidance_config.get("canned_response") | |||
| if sensitive_word_avoidance_config.get("canned_response") | |||
| else 'Your content violates our usage policy. Please revise and try again.', | |||
| extra_params={ | |||
| 'sensitive_words': sensitive_words.split(','), | |||
| } | |||
| ) | |||
| if sensitive_word_avoidance_rule: | |||
| return SensitiveWordAvoidanceChain( | |||
| sensitive_words=sensitive_words.split(","), | |||
| canned_response=sensitive_word_avoidance_config.get("canned_response", ''), | |||
| model_instance=model_instance, | |||
| sensitive_word_avoidance_rule=sensitive_word_avoidance_rule, | |||
| output_key="sensitive_word_avoidance_output", | |||
| callbacks=callbacks, | |||
| **kwargs | |||
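For clarity, the two shapes of `sensitive_word_avoidance_dict` this parser now handles; only the keys come from the code above, the values are illustrative:

```python
# OpenAI moderation path ("type" == 'moderation')
moderation_config = {
    "enabled": True,
    "type": "moderation",
    "canned_response": "",  # empty/missing -> the default canned response is used
}

# Keyword path (any other "type"); "words" is a comma-separated string
keywords_config = {
    "enabled": True,
    "type": "keywords",
    "words": "badword1,badword2",
    "canned_response": "Your content violates our usage policy. Please revise and try again.",
}
```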
| @@ -2,7 +2,7 @@ import json | |||
| import os | |||
| from unittest.mock import patch | |||
| from core.model_providers.models.moderation.openai_moderation import OpenAIModeration, DEFAULT_AUDIO_MODEL | |||
| from core.model_providers.models.moderation.openai_moderation import OpenAIModeration, DEFAULT_MODEL | |||
| from core.model_providers.providers.openai_provider import OpenAIProvider | |||
| from models.provider import Provider, ProviderType | |||
| @@ -23,7 +23,7 @@ def get_mock_openai_moderation_model(): | |||
| openai_provider = OpenAIProvider(provider=get_mock_provider(valid_openai_api_key)) | |||
| return OpenAIModeration( | |||
| model_provider=openai_provider, | |||
| name=DEFAULT_AUDIO_MODEL | |||
| name=DEFAULT_MODEL | |||
| ) | |||
| @@ -36,5 +36,4 @@ def test_run(mock_decrypt): | |||
| model = get_mock_openai_moderation_model() | |||
| rst = model.run('hello') | |||
| assert isinstance(rst, dict) | |||
| assert 'id' in rst | |||
| assert rst is True | |||