| @@ -164,7 +164,7 @@ def initialize_extensions(app): | |||
| @login_manager.request_loader | |||
| def load_user_from_request(request_from_flask_login): | |||
| """Load user based on the request.""" | |||
| if request.blueprint not in ["console", "inner_api"]: | |||
| if request.blueprint not in {"console", "inner_api"}: | |||
| return None | |||
| # Check if the user_id contains a dot, indicating the old format | |||
| auth_header = request.headers.get("Authorization", "") | |||
| @@ -140,9 +140,9 @@ def reset_encrypt_key_pair(): | |||
| @click.command("vdb-migrate", help="migrate vector db.") | |||
| @click.option("--scope", default="all", prompt=False, help="The scope of vector database to migrate, Default is All.") | |||
| def vdb_migrate(scope: str): | |||
| if scope in ["knowledge", "all"]: | |||
| if scope in {"knowledge", "all"}: | |||
| migrate_knowledge_vector_database() | |||
| if scope in ["annotation", "all"]: | |||
| if scope in {"annotation", "all"}: | |||
| migrate_annotation_vector_database() | |||
| @@ -94,7 +94,7 @@ class ChatMessageTextApi(Resource): | |||
| message_id = args.get("message_id", None) | |||
| text = args.get("text", None) | |||
| if ( | |||
| app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] | |||
| app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value} | |||
| and app_model.workflow | |||
| and app_model.workflow.features_dict | |||
| ): | |||
| @@ -71,7 +71,7 @@ class OAuthCallback(Resource): | |||
| account = _generate_account(provider, user_info) | |||
| # Check account status | |||
| if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value: | |||
| if account.status in {AccountStatus.BANNED.value, AccountStatus.CLOSED.value}: | |||
| return {"error": "Account is banned or closed."}, 403 | |||
| if account.status == AccountStatus.PENDING.value: | |||
| @@ -354,7 +354,7 @@ class DocumentIndexingEstimateApi(DocumentResource): | |||
| document_id = str(document_id) | |||
| document = self.get_document(dataset_id, document_id) | |||
| if document.indexing_status in ["completed", "error"]: | |||
| if document.indexing_status in {"completed", "error"}: | |||
| raise DocumentAlreadyFinishedError() | |||
| data_process_rule = document.dataset_process_rule | |||
| @@ -421,7 +421,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): | |||
| info_list = [] | |||
| extract_settings = [] | |||
| for document in documents: | |||
| if document.indexing_status in ["completed", "error"]: | |||
| if document.indexing_status in {"completed", "error"}: | |||
| raise DocumentAlreadyFinishedError() | |||
| data_source_info = document.data_source_info_dict | |||
| # format document files info | |||
| @@ -665,7 +665,7 @@ class DocumentProcessingApi(DocumentResource): | |||
| db.session.commit() | |||
| elif action == "resume": | |||
| if document.indexing_status not in ["paused", "error"]: | |||
| if document.indexing_status not in {"paused", "error"}: | |||
| raise InvalidActionError("Document not in paused or error state.") | |||
| document.paused_by = None | |||
| @@ -81,7 +81,7 @@ class ChatTextApi(InstalledAppResource): | |||
| message_id = args.get("message_id", None) | |||
| text = args.get("text", None) | |||
| if ( | |||
| app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] | |||
| app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value} | |||
| and app_model.workflow | |||
| and app_model.workflow.features_dict | |||
| ): | |||
| @@ -92,7 +92,7 @@ class ChatApi(InstalledAppResource): | |||
| def post(self, installed_app): | |||
| app_model = installed_app.app | |||
| app_mode = AppMode.value_of(app_model.mode) | |||
| if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: | |||
| if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: | |||
| raise NotChatAppError() | |||
| parser = reqparse.RequestParser() | |||
| @@ -140,7 +140,7 @@ class ChatStopApi(InstalledAppResource): | |||
| def post(self, installed_app, task_id): | |||
| app_model = installed_app.app | |||
| app_mode = AppMode.value_of(app_model.mode) | |||
| if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: | |||
| if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: | |||
| raise NotChatAppError() | |||
| AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) | |||
| @@ -20,7 +20,7 @@ class ConversationListApi(InstalledAppResource): | |||
| def get(self, installed_app): | |||
| app_model = installed_app.app | |||
| app_mode = AppMode.value_of(app_model.mode) | |||
| if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: | |||
| if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: | |||
| raise NotChatAppError() | |||
| parser = reqparse.RequestParser() | |||
| @@ -50,7 +50,7 @@ class ConversationApi(InstalledAppResource): | |||
| def delete(self, installed_app, c_id): | |||
| app_model = installed_app.app | |||
| app_mode = AppMode.value_of(app_model.mode) | |||
| if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: | |||
| if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: | |||
| raise NotChatAppError() | |||
| conversation_id = str(c_id) | |||
| @@ -68,7 +68,7 @@ class ConversationRenameApi(InstalledAppResource): | |||
| def post(self, installed_app, c_id): | |||
| app_model = installed_app.app | |||
| app_mode = AppMode.value_of(app_model.mode) | |||
| if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: | |||
| if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: | |||
| raise NotChatAppError() | |||
| conversation_id = str(c_id) | |||
| @@ -90,7 +90,7 @@ class ConversationPinApi(InstalledAppResource): | |||
| def patch(self, installed_app, c_id): | |||
| app_model = installed_app.app | |||
| app_mode = AppMode.value_of(app_model.mode) | |||
| if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: | |||
| if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: | |||
| raise NotChatAppError() | |||
| conversation_id = str(c_id) | |||
| @@ -107,7 +107,7 @@ class ConversationUnPinApi(InstalledAppResource): | |||
| def patch(self, installed_app, c_id): | |||
| app_model = installed_app.app | |||
| app_mode = AppMode.value_of(app_model.mode) | |||
| if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: | |||
| if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: | |||
| raise NotChatAppError() | |||
| conversation_id = str(c_id) | |||
| @@ -31,7 +31,7 @@ class InstalledAppsListApi(Resource): | |||
| "app_owner_tenant_id": installed_app.app_owner_tenant_id, | |||
| "is_pinned": installed_app.is_pinned, | |||
| "last_used_at": installed_app.last_used_at, | |||
| "editable": current_user.role in ["owner", "admin"], | |||
| "editable": current_user.role in {"owner", "admin"}, | |||
| "uninstallable": current_tenant_id == installed_app.app_owner_tenant_id, | |||
| } | |||
| for installed_app in installed_apps | |||
| @@ -40,7 +40,7 @@ class MessageListApi(InstalledAppResource): | |||
| app_model = installed_app.app | |||
| app_mode = AppMode.value_of(app_model.mode) | |||
| if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: | |||
| if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: | |||
| raise NotChatAppError() | |||
| parser = reqparse.RequestParser() | |||
| @@ -125,7 +125,7 @@ class MessageSuggestedQuestionApi(InstalledAppResource): | |||
| def get(self, installed_app, message_id): | |||
| app_model = installed_app.app | |||
| app_mode = AppMode.value_of(app_model.mode) | |||
| if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: | |||
| if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: | |||
| raise NotChatAppError() | |||
| message_id = str(message_id) | |||
| @@ -43,7 +43,7 @@ class AppParameterApi(InstalledAppResource): | |||
| """Retrieve app parameters.""" | |||
| app_model = installed_app.app | |||
| if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]: | |||
| if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: | |||
| workflow = app_model.workflow | |||
| if workflow is None: | |||
| raise AppUnavailableError() | |||
| @@ -194,7 +194,7 @@ class WebappLogoWorkspaceApi(Resource): | |||
| raise TooManyFilesError() | |||
| extension = file.filename.split(".")[-1] | |||
| if extension.lower() not in ["svg", "png"]: | |||
| if extension.lower() not in {"svg", "png"}: | |||
| raise UnsupportedFileTypeError() | |||
| try: | |||
| @@ -42,7 +42,7 @@ class AppParameterApi(Resource): | |||
| @marshal_with(parameters_fields) | |||
| def get(self, app_model: App): | |||
| """Retrieve app parameters.""" | |||
| if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]: | |||
| if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: | |||
| workflow = app_model.workflow | |||
| if workflow is None: | |||
| raise AppUnavailableError() | |||
| @@ -79,7 +79,7 @@ class TextApi(Resource): | |||
| message_id = args.get("message_id", None) | |||
| text = args.get("text", None) | |||
| if ( | |||
| app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] | |||
| app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value} | |||
| and app_model.workflow | |||
| and app_model.workflow.features_dict | |||
| ): | |||
| @@ -96,7 +96,7 @@ class ChatApi(Resource): | |||
| @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) | |||
| def post(self, app_model: App, end_user: EndUser): | |||
| app_mode = AppMode.value_of(app_model.mode) | |||
| if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: | |||
| if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: | |||
| raise NotChatAppError() | |||
| parser = reqparse.RequestParser() | |||
| @@ -144,7 +144,7 @@ class ChatStopApi(Resource): | |||
| @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) | |||
| def post(self, app_model: App, end_user: EndUser, task_id): | |||
| app_mode = AppMode.value_of(app_model.mode) | |||
| if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: | |||
| if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: | |||
| raise NotChatAppError() | |||
| AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id) | |||
| @@ -18,7 +18,7 @@ class ConversationApi(Resource): | |||
| @marshal_with(conversation_infinite_scroll_pagination_fields) | |||
| def get(self, app_model: App, end_user: EndUser): | |||
| app_mode = AppMode.value_of(app_model.mode) | |||
| if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: | |||
| if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: | |||
| raise NotChatAppError() | |||
| parser = reqparse.RequestParser() | |||
| @@ -52,7 +52,7 @@ class ConversationDetailApi(Resource): | |||
| @marshal_with(simple_conversation_fields) | |||
| def delete(self, app_model: App, end_user: EndUser, c_id): | |||
| app_mode = AppMode.value_of(app_model.mode) | |||
| if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: | |||
| if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: | |||
| raise NotChatAppError() | |||
| conversation_id = str(c_id) | |||
| @@ -69,7 +69,7 @@ class ConversationRenameApi(Resource): | |||
| @marshal_with(simple_conversation_fields) | |||
| def post(self, app_model: App, end_user: EndUser, c_id): | |||
| app_mode = AppMode.value_of(app_model.mode) | |||
| if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: | |||
| if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: | |||
| raise NotChatAppError() | |||
| conversation_id = str(c_id) | |||
| @@ -76,7 +76,7 @@ class MessageListApi(Resource): | |||
| @marshal_with(message_infinite_scroll_pagination_fields) | |||
| def get(self, app_model: App, end_user: EndUser): | |||
| app_mode = AppMode.value_of(app_model.mode) | |||
| if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: | |||
| if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: | |||
| raise NotChatAppError() | |||
| parser = reqparse.RequestParser() | |||
| @@ -117,7 +117,7 @@ class MessageSuggestedApi(Resource): | |||
| def get(self, app_model: App, end_user: EndUser, message_id): | |||
| message_id = str(message_id) | |||
| app_mode = AppMode.value_of(app_model.mode) | |||
| if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: | |||
| if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: | |||
| raise NotChatAppError() | |||
| try: | |||
| @@ -41,7 +41,7 @@ class AppParameterApi(WebApiResource): | |||
| @marshal_with(parameters_fields) | |||
| def get(self, app_model: App, end_user): | |||
| """Retrieve app parameters.""" | |||
| if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]: | |||
| if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: | |||
| workflow = app_model.workflow | |||
| if workflow is None: | |||
| raise AppUnavailableError() | |||
| @@ -78,7 +78,7 @@ class TextApi(WebApiResource): | |||
| message_id = args.get("message_id", None) | |||
| text = args.get("text", None) | |||
| if ( | |||
| app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] | |||
| app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value} | |||
| and app_model.workflow | |||
| and app_model.workflow.features_dict | |||
| ): | |||
| @@ -87,7 +87,7 @@ class CompletionStopApi(WebApiResource): | |||
| class ChatApi(WebApiResource): | |||
| def post(self, app_model, end_user): | |||
| app_mode = AppMode.value_of(app_model.mode) | |||
| if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: | |||
| if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: | |||
| raise NotChatAppError() | |||
| parser = reqparse.RequestParser() | |||
| @@ -136,7 +136,7 @@ class ChatApi(WebApiResource): | |||
| class ChatStopApi(WebApiResource): | |||
| def post(self, app_model, end_user, task_id): | |||
| app_mode = AppMode.value_of(app_model.mode) | |||
| if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: | |||
| if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: | |||
| raise NotChatAppError() | |||
| AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id) | |||
| @@ -18,7 +18,7 @@ class ConversationListApi(WebApiResource): | |||
| @marshal_with(conversation_infinite_scroll_pagination_fields) | |||
| def get(self, app_model, end_user): | |||
| app_mode = AppMode.value_of(app_model.mode) | |||
| if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: | |||
| if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: | |||
| raise NotChatAppError() | |||
| parser = reqparse.RequestParser() | |||
| @@ -56,7 +56,7 @@ class ConversationListApi(WebApiResource): | |||
| class ConversationApi(WebApiResource): | |||
| def delete(self, app_model, end_user, c_id): | |||
| app_mode = AppMode.value_of(app_model.mode) | |||
| if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: | |||
| if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: | |||
| raise NotChatAppError() | |||
| conversation_id = str(c_id) | |||
| @@ -73,7 +73,7 @@ class ConversationRenameApi(WebApiResource): | |||
| @marshal_with(simple_conversation_fields) | |||
| def post(self, app_model, end_user, c_id): | |||
| app_mode = AppMode.value_of(app_model.mode) | |||
| if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: | |||
| if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: | |||
| raise NotChatAppError() | |||
| conversation_id = str(c_id) | |||
| @@ -92,7 +92,7 @@ class ConversationRenameApi(WebApiResource): | |||
| class ConversationPinApi(WebApiResource): | |||
| def patch(self, app_model, end_user, c_id): | |||
| app_mode = AppMode.value_of(app_model.mode) | |||
| if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: | |||
| if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: | |||
| raise NotChatAppError() | |||
| conversation_id = str(c_id) | |||
| @@ -108,7 +108,7 @@ class ConversationPinApi(WebApiResource): | |||
| class ConversationUnPinApi(WebApiResource): | |||
| def patch(self, app_model, end_user, c_id): | |||
| app_mode = AppMode.value_of(app_model.mode) | |||
| if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: | |||
| if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: | |||
| raise NotChatAppError() | |||
| conversation_id = str(c_id) | |||
| @@ -78,7 +78,7 @@ class MessageListApi(WebApiResource): | |||
| @marshal_with(message_infinite_scroll_pagination_fields) | |||
| def get(self, app_model, end_user): | |||
| app_mode = AppMode.value_of(app_model.mode) | |||
| if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: | |||
| if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: | |||
| raise NotChatAppError() | |||
| parser = reqparse.RequestParser() | |||
| @@ -160,7 +160,7 @@ class MessageMoreLikeThisApi(WebApiResource): | |||
| class MessageSuggestedQuestionApi(WebApiResource): | |||
| def get(self, app_model, end_user, message_id): | |||
| app_mode = AppMode.value_of(app_model.mode) | |||
| if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: | |||
| if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: | |||
| raise NotCompletionAppError() | |||
| message_id = str(message_id) | |||
| @@ -90,7 +90,7 @@ class CotAgentOutputParser: | |||
| if not in_code_block and not in_json: | |||
| if delta.lower() == action_str[action_idx] and action_idx == 0: | |||
| if last_character not in ["\n", " ", ""]: | |||
| if last_character not in {"\n", " ", ""}: | |||
| index += steps | |||
| yield delta | |||
| continue | |||
| @@ -117,7 +117,7 @@ class CotAgentOutputParser: | |||
| action_idx = 0 | |||
| if delta.lower() == thought_str[thought_idx] and thought_idx == 0: | |||
| if last_character not in ["\n", " ", ""]: | |||
| if last_character not in {"\n", " ", ""}: | |||
| index += steps | |||
| yield delta | |||
| continue | |||
| @@ -29,7 +29,7 @@ class BaseAppConfigManager: | |||
| additional_features.show_retrieve_source = RetrievalResourceConfigManager.convert(config=config_dict) | |||
| additional_features.file_upload = FileUploadConfigManager.convert( | |||
| config=config_dict, is_vision=app_mode in [AppMode.CHAT, AppMode.COMPLETION, AppMode.AGENT_CHAT] | |||
| config=config_dict, is_vision=app_mode in {AppMode.CHAT, AppMode.COMPLETION, AppMode.AGENT_CHAT} | |||
| ) | |||
| additional_features.opening_statement, additional_features.suggested_questions = ( | |||
| @@ -18,7 +18,7 @@ class AgentConfigManager: | |||
| if agent_strategy == "function_call": | |||
| strategy = AgentEntity.Strategy.FUNCTION_CALLING | |||
| elif agent_strategy == "cot" or agent_strategy == "react": | |||
| elif agent_strategy in {"cot", "react"}: | |||
| strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT | |||
| else: | |||
| # old configs, try to detect default strategy | |||
| @@ -43,10 +43,10 @@ class AgentConfigManager: | |||
| agent_tools.append(AgentToolEntity(**agent_tool_properties)) | |||
| if "strategy" in config["agent_mode"] and config["agent_mode"]["strategy"] not in [ | |||
| if "strategy" in config["agent_mode"] and config["agent_mode"]["strategy"] not in { | |||
| "react_router", | |||
| "router", | |||
| ]: | |||
| }: | |||
| agent_prompt = agent_dict.get("prompt", None) or {} | |||
| # check model mode | |||
| model_mode = config.get("model", {}).get("mode", "completion") | |||
| @@ -167,7 +167,7 @@ class DatasetConfigManager: | |||
| config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value | |||
| has_datasets = False | |||
| if config["agent_mode"]["strategy"] in [PlanningStrategy.ROUTER.value, PlanningStrategy.REACT_ROUTER.value]: | |||
| if config["agent_mode"]["strategy"] in {PlanningStrategy.ROUTER.value, PlanningStrategy.REACT_ROUTER.value}: | |||
| for tool in config["agent_mode"]["tools"]: | |||
| key = list(tool.keys())[0] | |||
| if key == "dataset": | |||
| @@ -42,12 +42,12 @@ class BasicVariablesConfigManager: | |||
| variable=variable["variable"], type=variable["type"], config=variable["config"] | |||
| ) | |||
| ) | |||
| elif variable_type in [ | |||
| elif variable_type in { | |||
| VariableEntityType.TEXT_INPUT, | |||
| VariableEntityType.PARAGRAPH, | |||
| VariableEntityType.NUMBER, | |||
| VariableEntityType.SELECT, | |||
| ]: | |||
| }: | |||
| variable = variables[variable_type] | |||
| variable_entities.append( | |||
| VariableEntity( | |||
| @@ -97,7 +97,7 @@ class BasicVariablesConfigManager: | |||
| variables = [] | |||
| for item in config["user_input_form"]: | |||
| key = list(item.keys())[0] | |||
| if key not in ["text-input", "select", "paragraph", "number", "external_data_tool"]: | |||
| if key not in {"text-input", "select", "paragraph", "number", "external_data_tool"}: | |||
| raise ValueError("Keys in user_input_form list can only be 'text-input', 'paragraph' or 'select'") | |||
| form_item = item[key] | |||
| @@ -54,14 +54,14 @@ class FileUploadConfigManager: | |||
| if is_vision: | |||
| detail = config["file_upload"]["image"]["detail"] | |||
| if detail not in ["high", "low"]: | |||
| if detail not in {"high", "low"}: | |||
| raise ValueError("detail must be in ['high', 'low']") | |||
| transfer_methods = config["file_upload"]["image"]["transfer_methods"] | |||
| if not isinstance(transfer_methods, list): | |||
| raise ValueError("transfer_methods must be of list type") | |||
| for method in transfer_methods: | |||
| if method not in ["remote_url", "local_file"]: | |||
| if method not in {"remote_url", "local_file"}: | |||
| raise ValueError("transfer_methods must be in ['remote_url', 'local_file']") | |||
| return config, ["file_upload"] | |||
| @@ -73,7 +73,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): | |||
| raise ValueError("Workflow not initialized") | |||
| user_id = None | |||
| if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: | |||
| if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}: | |||
| end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first() | |||
| if end_user: | |||
| user_id = end_user.session_id | |||
| @@ -175,7 +175,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): | |||
| user_id=self.application_generate_entity.user_id, | |||
| user_from=( | |||
| UserFrom.ACCOUNT | |||
| if self.application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] | |||
| if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} | |||
| else UserFrom.END_USER | |||
| ), | |||
| invoke_from=self.application_generate_entity.invoke_from, | |||
| @@ -16,7 +16,7 @@ class AppGenerateResponseConverter(ABC): | |||
| def convert( | |||
| cls, response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], invoke_from: InvokeFrom | |||
| ) -> dict[str, Any] | Generator[str, Any, None]: | |||
| if invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]: | |||
| if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}: | |||
| if isinstance(response, AppBlockingResponse): | |||
| return cls.convert_blocking_full_response(response) | |||
| else: | |||
| @@ -22,11 +22,11 @@ class BaseAppGenerator: | |||
| return var.default or "" | |||
| if ( | |||
| var.type | |||
| in ( | |||
| in { | |||
| VariableEntityType.TEXT_INPUT, | |||
| VariableEntityType.SELECT, | |||
| VariableEntityType.PARAGRAPH, | |||
| ) | |||
| } | |||
| and user_input_value | |||
| and not isinstance(user_input_value, str) | |||
| ): | |||
| @@ -44,7 +44,7 @@ class BaseAppGenerator: | |||
| options = var.options or [] | |||
| if user_input_value not in options: | |||
| raise ValueError(f"{var.variable} in input form must be one of the following: {options}") | |||
| elif var.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH): | |||
| elif var.type in {VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH}: | |||
| if var.max_length and user_input_value and len(user_input_value) > var.max_length: | |||
| raise ValueError(f"{var.variable} in input form must be less than {var.max_length} characters") | |||
| @@ -32,7 +32,7 @@ class AppQueueManager: | |||
| self._user_id = user_id | |||
| self._invoke_from = invoke_from | |||
| user_prefix = "account" if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end-user" | |||
| user_prefix = "account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user" | |||
| redis_client.setex( | |||
| AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}" | |||
| ) | |||
| @@ -118,7 +118,7 @@ class AppQueueManager: | |||
| if result is None: | |||
| return | |||
| user_prefix = "account" if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end-user" | |||
| user_prefix = "account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user" | |||
| if result.decode("utf-8") != f"{user_prefix}-{user_id}": | |||
| return | |||
| @@ -148,7 +148,7 @@ class MessageBasedAppGenerator(BaseAppGenerator): | |||
| # get from source | |||
| end_user_id = None | |||
| account_id = None | |||
| if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: | |||
| if application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}: | |||
| from_source = "api" | |||
| end_user_id = application_generate_entity.user_id | |||
| else: | |||
| @@ -165,11 +165,11 @@ class MessageBasedAppGenerator(BaseAppGenerator): | |||
| model_provider = application_generate_entity.model_conf.provider | |||
| model_id = application_generate_entity.model_conf.model | |||
| override_model_configs = None | |||
| if app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS and app_config.app_mode in [ | |||
| if app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS and app_config.app_mode in { | |||
| AppMode.AGENT_CHAT, | |||
| AppMode.CHAT, | |||
| AppMode.COMPLETION, | |||
| ]: | |||
| }: | |||
| override_model_configs = app_config.app_model_config_dict | |||
| # get conversation introduction | |||
| @@ -53,7 +53,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): | |||
| app_config = cast(WorkflowAppConfig, app_config) | |||
| user_id = None | |||
| if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: | |||
| if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}: | |||
| end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first() | |||
| if end_user: | |||
| user_id = end_user.session_id | |||
| @@ -113,7 +113,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): | |||
| user_id=self.application_generate_entity.user_id, | |||
| user_from=( | |||
| UserFrom.ACCOUNT | |||
| if self.application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] | |||
| if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} | |||
| else UserFrom.END_USER | |||
| ), | |||
| invoke_from=self.application_generate_entity.invoke_from, | |||
| @@ -63,7 +63,7 @@ class AnnotationReplyFeature: | |||
| score = documents[0].metadata["score"] | |||
| annotation = AppAnnotationService.get_annotation_by_id(annotation_id) | |||
| if annotation: | |||
| if invoke_from in [InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP]: | |||
| if invoke_from in {InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP}: | |||
| from_source = "api" | |||
| else: | |||
| from_source = "console" | |||
| @@ -372,7 +372,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan | |||
| self._message, | |||
| application_generate_entity=self._application_generate_entity, | |||
| conversation=self._conversation, | |||
| is_first_message=self._application_generate_entity.app_config.app_mode in [AppMode.AGENT_CHAT, AppMode.CHAT] | |||
| is_first_message=self._application_generate_entity.app_config.app_mode in {AppMode.AGENT_CHAT, AppMode.CHAT} | |||
| and self._application_generate_entity.conversation_id is None, | |||
| extras=self._application_generate_entity.extras, | |||
| ) | |||
| @@ -383,7 +383,7 @@ class WorkflowCycleManage: | |||
| :param workflow_node_execution: workflow node execution | |||
| :return: | |||
| """ | |||
| if workflow_node_execution.node_type in [NodeType.ITERATION.value, NodeType.LOOP.value]: | |||
| if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: | |||
| return None | |||
| response = NodeStartStreamResponse( | |||
| @@ -430,7 +430,7 @@ class WorkflowCycleManage: | |||
| :param workflow_node_execution: workflow node execution | |||
| :return: | |||
| """ | |||
| if workflow_node_execution.node_type in [NodeType.ITERATION.value, NodeType.LOOP.value]: | |||
| if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: | |||
| return None | |||
| return NodeFinishStreamResponse( | |||
| @@ -29,7 +29,7 @@ class DatasetIndexToolCallbackHandler: | |||
| source="app", | |||
| source_app_id=self._app_id, | |||
| created_by_role=( | |||
| "account" if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end_user" | |||
| "account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user" | |||
| ), | |||
| created_by=self._user_id, | |||
| ) | |||
| @@ -292,7 +292,7 @@ class IndexingRunner: | |||
| self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict | |||
| ) -> list[Document]: | |||
| # load file | |||
| if dataset_document.data_source_type not in ["upload_file", "notion_import", "website_crawl"]: | |||
| if dataset_document.data_source_type not in {"upload_file", "notion_import", "website_crawl"}: | |||
| return [] | |||
| data_source_info = dataset_document.data_source_info_dict | |||
| @@ -52,7 +52,7 @@ class TokenBufferMemory: | |||
| files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all() | |||
| if files: | |||
| file_extra_config = None | |||
| if self.conversation.mode not in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]: | |||
| if self.conversation.mode not in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: | |||
| file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config) | |||
| else: | |||
| if message.workflow_run_id: | |||
| @@ -27,17 +27,17 @@ class ModelType(Enum): | |||
| :return: model type | |||
| """ | |||
| if origin_model_type == "text-generation" or origin_model_type == cls.LLM.value: | |||
| if origin_model_type in {"text-generation", cls.LLM.value}: | |||
| return cls.LLM | |||
| elif origin_model_type == "embeddings" or origin_model_type == cls.TEXT_EMBEDDING.value: | |||
| elif origin_model_type in {"embeddings", cls.TEXT_EMBEDDING.value}: | |||
| return cls.TEXT_EMBEDDING | |||
| elif origin_model_type == "reranking" or origin_model_type == cls.RERANK.value: | |||
| elif origin_model_type in {"reranking", cls.RERANK.value}: | |||
| return cls.RERANK | |||
| elif origin_model_type == "speech2text" or origin_model_type == cls.SPEECH2TEXT.value: | |||
| elif origin_model_type in {"speech2text", cls.SPEECH2TEXT.value}: | |||
| return cls.SPEECH2TEXT | |||
| elif origin_model_type == "tts" or origin_model_type == cls.TTS.value: | |||
| elif origin_model_type in {"tts", cls.TTS.value}: | |||
| return cls.TTS | |||
| elif origin_model_type == "text2img" or origin_model_type == cls.TEXT2IMG.value: | |||
| elif origin_model_type in {"text2img", cls.TEXT2IMG.value}: | |||
| return cls.TEXT2IMG | |||
| elif origin_model_type == cls.MODERATION.value: | |||
| return cls.MODERATION | |||
| @@ -494,7 +494,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): | |||
| mime_type = data_split[0].replace("data:", "") | |||
| base64_data = data_split[1] | |||
| if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]: | |||
| if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}: | |||
| raise ValueError( | |||
| f"Unsupported image type {mime_type}, " | |||
| f"only support image/jpeg, image/png, image/gif, and image/webp" | |||
| @@ -85,14 +85,14 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel): | |||
| for i in range(len(sentences)) | |||
| ] | |||
| for future in futures: | |||
| yield from future.result().__enter__().iter_bytes(1024) | |||
| yield from future.result().__enter__().iter_bytes(1024) # noqa:PLC2801 | |||
| else: | |||
| response = client.audio.speech.with_streaming_response.create( | |||
| model=model, voice=voice, response_format="mp3", input=content_text.strip() | |||
| ) | |||
| yield from response.__enter__().iter_bytes(1024) | |||
| yield from response.__enter__().iter_bytes(1024) # noqa:PLC2801 | |||
| except Exception as ex: | |||
| raise InvokeBadRequestError(str(ex)) | |||
| @@ -454,7 +454,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel): | |||
| base64_data = data_split[1] | |||
| image_content = base64.b64decode(base64_data) | |||
| if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]: | |||
| if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}: | |||
| raise ValueError( | |||
| f"Unsupported image type {mime_type}, " | |||
| f"only support image/jpeg, image/png, image/gif, and image/webp" | |||
| @@ -886,16 +886,16 @@ class BedrockLargeLanguageModel(LargeLanguageModel): | |||
| if error_code == "AccessDeniedException": | |||
| return InvokeAuthorizationError(error_msg) | |||
| elif error_code in ["ResourceNotFoundException", "ValidationException"]: | |||
| elif error_code in {"ResourceNotFoundException", "ValidationException"}: | |||
| return InvokeBadRequestError(error_msg) | |||
| elif error_code in ["ThrottlingException", "ServiceQuotaExceededException"]: | |||
| elif error_code in {"ThrottlingException", "ServiceQuotaExceededException"}: | |||
| return InvokeRateLimitError(error_msg) | |||
| elif error_code in [ | |||
| elif error_code in { | |||
| "ModelTimeoutException", | |||
| "ModelErrorException", | |||
| "InternalServerException", | |||
| "ModelNotReadyException", | |||
| ]: | |||
| }: | |||
| return InvokeServerUnavailableError(error_msg) | |||
| elif error_code == "ModelStreamErrorException": | |||
| return InvokeConnectionError(error_msg) | |||
| @@ -186,16 +186,16 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel): | |||
| if error_code == "AccessDeniedException": | |||
| return InvokeAuthorizationError(error_msg) | |||
| elif error_code in ["ResourceNotFoundException", "ValidationException"]: | |||
| elif error_code in {"ResourceNotFoundException", "ValidationException"}: | |||
| return InvokeBadRequestError(error_msg) | |||
| elif error_code in ["ThrottlingException", "ServiceQuotaExceededException"]: | |||
| elif error_code in {"ThrottlingException", "ServiceQuotaExceededException"}: | |||
| return InvokeRateLimitError(error_msg) | |||
| elif error_code in [ | |||
| elif error_code in { | |||
| "ModelTimeoutException", | |||
| "ModelErrorException", | |||
| "InternalServerException", | |||
| "ModelNotReadyException", | |||
| ]: | |||
| }: | |||
| return InvokeServerUnavailableError(error_msg) | |||
| elif error_code == "ModelStreamErrorException": | |||
| return InvokeConnectionError(error_msg) | |||
| @@ -6,10 +6,10 @@ from collections.abc import Generator | |||
| from typing import Optional, Union, cast | |||
| import google.ai.generativelanguage as glm | |||
| import google.api_core.exceptions as exceptions | |||
| import google.generativeai as genai | |||
| import google.generativeai.client as client | |||
| import requests | |||
| from google.api_core import exceptions | |||
| from google.generativeai import client | |||
| from google.generativeai.types import ContentType, GenerateContentResponse, HarmBlockThreshold, HarmCategory | |||
| from google.generativeai.types.content_types import to_part | |||
| from PIL import Image | |||
| @@ -77,7 +77,7 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel | |||
| if "huggingfacehub_api_type" not in credentials: | |||
| raise CredentialsValidateFailedError("Huggingface Hub Endpoint Type must be provided.") | |||
| if credentials["huggingfacehub_api_type"] not in ("inference_endpoints", "hosted_inference_api"): | |||
| if credentials["huggingfacehub_api_type"] not in {"inference_endpoints", "hosted_inference_api"}: | |||
| raise CredentialsValidateFailedError("Huggingface Hub Endpoint Type is invalid.") | |||
| if "huggingfacehub_api_token" not in credentials: | |||
| @@ -94,7 +94,7 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel | |||
| credentials["huggingfacehub_api_token"], model | |||
| ) | |||
| if credentials["task_type"] not in ("text2text-generation", "text-generation"): | |||
| if credentials["task_type"] not in {"text2text-generation", "text-generation"}: | |||
| raise CredentialsValidateFailedError( | |||
| "Huggingface Hub Task Type must be one of text2text-generation, text-generation." | |||
| ) | |||
| @@ -75,7 +75,7 @@ class TeiHelper: | |||
| if len(model_type.keys()) < 1: | |||
| raise RuntimeError("model_type is empty") | |||
| model_type = list(model_type.keys())[0] | |||
| if model_type not in ["embedding", "reranker"]: | |||
| if model_type not in {"embedding", "reranker"}: | |||
| raise RuntimeError(f"invalid model_type: {model_type}") | |||
| max_input_length = response_json.get("max_input_length", 512) | |||
| @@ -100,9 +100,9 @@ class MinimaxChatCompletion: | |||
| return self._handle_chat_generate_response(response) | |||
| def _handle_error(self, code: int, msg: str): | |||
| if code == 1000 or code == 1001 or code == 1013 or code == 1027: | |||
| if code in {1000, 1001, 1013, 1027}: | |||
| raise InternalServerError(msg) | |||
| elif code == 1002 or code == 1039: | |||
| elif code in {1002, 1039}: | |||
| raise RateLimitReachedError(msg) | |||
| elif code == 1004: | |||
| raise InvalidAuthenticationError(msg) | |||
| @@ -105,9 +105,9 @@ class MinimaxChatCompletionPro: | |||
| return self._handle_chat_generate_response(response) | |||
| def _handle_error(self, code: int, msg: str): | |||
| if code == 1000 or code == 1001 or code == 1013 or code == 1027: | |||
| if code in {1000, 1001, 1013, 1027}: | |||
| raise InternalServerError(msg) | |||
| elif code == 1002 or code == 1039: | |||
| elif code in {1002, 1039}: | |||
| raise RateLimitReachedError(msg) | |||
| elif code == 1004: | |||
| raise InvalidAuthenticationError(msg) | |||
| @@ -114,7 +114,7 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel): | |||
| raise CredentialsValidateFailedError("Invalid api key") | |||
| def _handle_error(self, code: int, msg: str): | |||
| if code == 1000 or code == 1001: | |||
| if code in {1000, 1001}: | |||
| raise InternalServerError(msg) | |||
| elif code == 1002: | |||
| raise RateLimitReachedError(msg) | |||
| @@ -125,7 +125,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): | |||
| model_mode = self.get_model_mode(base_model, credentials) | |||
| # transform response format | |||
| if "response_format" in model_parameters and model_parameters["response_format"] in ["JSON", "XML"]: | |||
| if "response_format" in model_parameters and model_parameters["response_format"] in {"JSON", "XML"}: | |||
| stop = stop or [] | |||
| if model_mode == LLMMode.CHAT: | |||
| # chat model | |||
| @@ -89,14 +89,14 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel): | |||
| for i in range(len(sentences)) | |||
| ] | |||
| for future in futures: | |||
| yield from future.result().__enter__().iter_bytes(1024) | |||
| yield from future.result().__enter__().iter_bytes(1024) # noqa:PLC2801 | |||
| else: | |||
| response = client.audio.speech.with_streaming_response.create( | |||
| model=model, voice=voice, response_format="mp3", input=content_text.strip() | |||
| ) | |||
| yield from response.__enter__().iter_bytes(1024) | |||
| yield from response.__enter__().iter_bytes(1024) # noqa:PLC2801 | |||
| except Exception as ex: | |||
| raise InvokeBadRequestError(str(ex)) | |||
| @@ -12,7 +12,6 @@ class OpenRouterLargeLanguageModel(OAIAPICompatLargeLanguageModel): | |||
| credentials["endpoint_url"] = "https://openrouter.ai/api/v1" | |||
| credentials["mode"] = self.get_model_mode(model).value | |||
| credentials["function_calling_type"] = "tool_call" | |||
| return | |||
| def _invoke( | |||
| self, | |||
| @@ -154,7 +154,7 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel): | |||
| ) | |||
| for key, value in input_properties: | |||
| if key not in ["system_prompt", "prompt"] and "stop" not in key: | |||
| if key not in {"system_prompt", "prompt"} and "stop" not in key: | |||
| value_type = value.get("type") | |||
| if not value_type: | |||
| @@ -86,7 +86,7 @@ class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel): | |||
| ) | |||
| for input_property in input_properties: | |||
| if input_property[0] in ("text", "texts", "inputs"): | |||
| if input_property[0] in {"text", "texts", "inputs"}: | |||
| text_input_key = input_property[0] | |||
| return text_input_key | |||
| @@ -96,7 +96,7 @@ class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel): | |||
| def _generate_embeddings_by_text_input_key( | |||
| client: ReplicateClient, replicate_model_version: str, text_input_key: str, texts: list[str] | |||
| ) -> list[list[float]]: | |||
| if text_input_key in ("text", "inputs"): | |||
| if text_input_key in {"text", "inputs"}: | |||
| embeddings = [] | |||
| for text in texts: | |||
| result = client.run(replicate_model_version, input={text_input_key: text}) | |||
| @@ -89,7 +89,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel): | |||
| :param tools: tools for tool calling | |||
| :return: | |||
| """ | |||
| if model in ["qwen-turbo-chat", "qwen-plus-chat"]: | |||
| if model in {"qwen-turbo-chat", "qwen-plus-chat"}: | |||
| model = model.replace("-chat", "") | |||
| if model == "farui-plus": | |||
| model = "qwen-farui-plus" | |||
| @@ -157,7 +157,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel): | |||
| mode = self.get_model_mode(model, credentials) | |||
| if model in ["qwen-turbo-chat", "qwen-plus-chat"]: | |||
| if model in {"qwen-turbo-chat", "qwen-plus-chat"}: | |||
| model = model.replace("-chat", "") | |||
| extra_model_kwargs = {} | |||
| @@ -201,7 +201,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel): | |||
| :param prompt_messages: prompt messages | |||
| :return: llm response | |||
| """ | |||
| if response.status_code != 200 and response.status_code != HTTPStatus.OK: | |||
| if response.status_code not in {200, HTTPStatus.OK}: | |||
| raise ServiceUnavailableError(response.message) | |||
| # transform assistant message to prompt message | |||
| assistant_prompt_message = AssistantPromptMessage( | |||
| @@ -240,7 +240,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel): | |||
| full_text = "" | |||
| tool_calls = [] | |||
| for index, response in enumerate(responses): | |||
| if response.status_code != 200 and response.status_code != HTTPStatus.OK: | |||
| if response.status_code not in {200, HTTPStatus.OK}: | |||
| raise ServiceUnavailableError( | |||
| f"Failed to invoke model {model}, status code: {response.status_code}, " | |||
| f"message: {response.message}" | |||
| @@ -93,7 +93,7 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): | |||
| """ | |||
| Code block mode wrapper for invoking large language model | |||
| """ | |||
| if "response_format" in model_parameters and model_parameters["response_format"] in ["JSON", "XML"]: | |||
| if "response_format" in model_parameters and model_parameters["response_format"] in {"JSON", "XML"}: | |||
| stop = stop or [] | |||
| self._transform_chat_json_prompts( | |||
| model=model, | |||
| @@ -5,7 +5,6 @@ import logging | |||
| from collections.abc import Generator | |||
| from typing import Optional, Union, cast | |||
| import google.api_core.exceptions as exceptions | |||
| import google.auth.transport.requests | |||
| import vertexai.generative_models as glm | |||
| from anthropic import AnthropicVertex, Stream | |||
| @@ -17,6 +16,7 @@ from anthropic.types import ( | |||
| MessageStopEvent, | |||
| MessageStreamEvent, | |||
| ) | |||
| from google.api_core import exceptions | |||
| from google.cloud import aiplatform | |||
| from google.oauth2 import service_account | |||
| from PIL import Image | |||
| @@ -346,7 +346,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): | |||
| mime_type = data_split[0].replace("data:", "") | |||
| base64_data = data_split[1] | |||
| if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]: | |||
| if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}: | |||
| raise ValueError( | |||
| f"Unsupported image type {mime_type}, " | |||
| f"only support image/jpeg, image/png, image/gif, and image/webp" | |||
| @@ -96,7 +96,6 @@ class Signer: | |||
| signing_key = Signer.get_signing_secret_key_v4(credentials.sk, md.date, md.region, md.service) | |||
| sign = Util.to_hex(Util.hmac_sha256(signing_key, signing_str)) | |||
| request.headers["Authorization"] = Signer.build_auth_header_v4(sign, md, credentials) | |||
| return | |||
| @staticmethod | |||
| def hashed_canonical_request_v4(request, meta): | |||
| @@ -105,7 +104,7 @@ class Signer: | |||
| signed_headers = {} | |||
| for key in request.headers: | |||
| if key in ["Content-Type", "Content-Md5", "Host"] or key.startswith("X-"): | |||
| if key in {"Content-Type", "Content-Md5", "Host"} or key.startswith("X-"): | |||
| signed_headers[key.lower()] = request.headers[key] | |||
| if "host" in signed_headers: | |||
| @@ -69,7 +69,7 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel): | |||
| """ | |||
| Code block mode wrapper for invoking large language model | |||
| """ | |||
| if "response_format" in model_parameters and model_parameters["response_format"] in ["JSON", "XML"]: | |||
| if "response_format" in model_parameters and model_parameters["response_format"] in {"JSON", "XML"}: | |||
| response_format = model_parameters["response_format"] | |||
| stop = stop or [] | |||
| self._transform_json_prompts( | |||
| @@ -103,7 +103,7 @@ class XinferenceHelper: | |||
| model_handle_type = "embedding" | |||
| elif response_json.get("model_type") == "audio": | |||
| model_handle_type = "audio" | |||
| if model_family and model_family in ["ChatTTS", "CosyVoice", "FishAudio"]: | |||
| if model_family and model_family in {"ChatTTS", "CosyVoice", "FishAudio"}: | |||
| model_ability.append("text-to-audio") | |||
| else: | |||
| model_ability.append("audio-to-text") | |||
| @@ -186,10 +186,10 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): | |||
| new_prompt_messages: list[PromptMessage] = [] | |||
| for prompt_message in prompt_messages: | |||
| copy_prompt_message = prompt_message.copy() | |||
| if copy_prompt_message.role in [PromptMessageRole.USER, PromptMessageRole.SYSTEM, PromptMessageRole.TOOL]: | |||
| if copy_prompt_message.role in {PromptMessageRole.USER, PromptMessageRole.SYSTEM, PromptMessageRole.TOOL}: | |||
| if isinstance(copy_prompt_message.content, list): | |||
| # check if model is 'glm-4v' | |||
| if model not in ("glm-4v", "glm-4v-plus"): | |||
| if model not in {"glm-4v", "glm-4v-plus"}: | |||
| # not support list message | |||
| continue | |||
| # get image and | |||
| @@ -209,10 +209,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): | |||
| ): | |||
| new_prompt_messages[-1].content += "\n\n" + copy_prompt_message.content | |||
| else: | |||
| if ( | |||
| copy_prompt_message.role == PromptMessageRole.USER | |||
| or copy_prompt_message.role == PromptMessageRole.TOOL | |||
| ): | |||
| if copy_prompt_message.role in {PromptMessageRole.USER, PromptMessageRole.TOOL}: | |||
| new_prompt_messages.append(copy_prompt_message) | |||
| elif copy_prompt_message.role == PromptMessageRole.SYSTEM: | |||
| new_prompt_message = SystemPromptMessage(content=copy_prompt_message.content) | |||
| @@ -226,7 +223,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): | |||
| else: | |||
| new_prompt_messages.append(copy_prompt_message) | |||
| if model == "glm-4v" or model == "glm-4v-plus": | |||
| if model in {"glm-4v", "glm-4v-plus"}: | |||
| params = self._construct_glm_4v_parameter(model, new_prompt_messages, model_parameters) | |||
| else: | |||
| params = {"model": model, "messages": [], **model_parameters} | |||
| @@ -270,11 +267,11 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): | |||
| # chatglm model | |||
| for prompt_message in new_prompt_messages: | |||
| # merge system message to user message | |||
| if ( | |||
| prompt_message.role == PromptMessageRole.SYSTEM | |||
| or prompt_message.role == PromptMessageRole.TOOL | |||
| or prompt_message.role == PromptMessageRole.USER | |||
| ): | |||
| if prompt_message.role in { | |||
| PromptMessageRole.SYSTEM, | |||
| PromptMessageRole.TOOL, | |||
| PromptMessageRole.USER, | |||
| }: | |||
| if len(params["messages"]) > 0 and params["messages"][-1]["role"] == "user": | |||
| params["messages"][-1]["content"] += "\n\n" + prompt_message.content | |||
| else: | |||
| @@ -1,5 +1,4 @@ | |||
| from __future__ import annotations | |||
| from .fine_tuning_job import FineTuningJob as FineTuningJob | |||
| from .fine_tuning_job import ListOfFineTuningJob as ListOfFineTuningJob | |||
| from .fine_tuning_job_event import FineTuningJobEvent as FineTuningJobEvent | |||
| from .fine_tuning_job import FineTuningJob, ListOfFineTuningJob | |||
| from .fine_tuning_job_event import FineTuningJobEvent | |||
| @@ -75,7 +75,7 @@ class CommonValidator: | |||
| if not isinstance(value, str): | |||
| raise ValueError(f"Variable {credential_form_schema.variable} should be string") | |||
| if credential_form_schema.type in [FormType.SELECT, FormType.RADIO]: | |||
| if credential_form_schema.type in {FormType.SELECT, FormType.RADIO}: | |||
| # If the value is in options, no validation is performed | |||
| if credential_form_schema.options: | |||
| if value not in [option.value for option in credential_form_schema.options]: | |||
| @@ -83,7 +83,7 @@ class CommonValidator: | |||
| if credential_form_schema.type == FormType.SWITCH: | |||
| # If the value is not in ['true', 'false'], an exception is thrown | |||
| if value.lower() not in ["true", "false"]: | |||
| if value.lower() not in {"true", "false"}: | |||
| raise ValueError(f"Variable {credential_form_schema.variable} should be true or false") | |||
| value = True if value.lower() == "true" else False | |||
| @@ -51,7 +51,7 @@ class ElasticSearchVector(BaseVector): | |||
| def _init_client(self, config: ElasticSearchConfig) -> Elasticsearch: | |||
| try: | |||
| parsed_url = urlparse(config.host) | |||
| if parsed_url.scheme in ["http", "https"]: | |||
| if parsed_url.scheme in {"http", "https"}: | |||
| hosts = f"{config.host}:{config.port}" | |||
| else: | |||
| hosts = f"http://{config.host}:{config.port}" | |||
| @@ -94,7 +94,7 @@ class ElasticSearchVector(BaseVector): | |||
| return uuids | |||
| def text_exists(self, id: str) -> bool: | |||
| return self._client.exists(index=self._collection_name, id=id).__bool__() | |||
| return bool(self._client.exists(index=self._collection_name, id=id)) | |||
| def delete_by_ids(self, ids: list[str]) -> None: | |||
| for id in ids: | |||
| @@ -35,7 +35,7 @@ class MyScaleVector(BaseVector): | |||
| super().__init__(collection_name) | |||
| self._config = config | |||
| self._metric = metric | |||
| self._vec_order = SortOrder.ASC if metric.upper() in ["COSINE", "L2"] else SortOrder.DESC | |||
| self._vec_order = SortOrder.ASC if metric.upper() in {"COSINE", "L2"} else SortOrder.DESC | |||
| self._client = get_client( | |||
| host=config.host, | |||
| port=config.port, | |||
| @@ -92,7 +92,7 @@ class MyScaleVector(BaseVector): | |||
| @staticmethod | |||
| def escape_str(value: Any) -> str: | |||
| return "".join(" " if c in ("\\", "'") else c for c in str(value)) | |||
| return "".join(" " if c in {"\\", "'"} else c for c in str(value)) | |||
| def text_exists(self, id: str) -> bool: | |||
| results = self._client.query(f"SELECT id FROM {self._config.database}.{self._collection_name} WHERE id='{id}'") | |||
| @@ -223,15 +223,7 @@ class OracleVector(BaseVector): | |||
| words = pseg.cut(query) | |||
| current_entity = "" | |||
| for word, pos in words: | |||
| if ( | |||
| pos == "nr" | |||
| or pos == "Ng" | |||
| or pos == "eng" | |||
| or pos == "nz" | |||
| or pos == "n" | |||
| or pos == "ORG" | |||
| or pos == "v" | |||
| ): # nr: 人名, ns: 地名, nt: 机构名 | |||
| if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}: # nr: 人名, ns: 地名, nt: 机构名 | |||
| current_entity += word | |||
| else: | |||
| if current_entity: | |||
| @@ -98,17 +98,17 @@ class ExtractProcessor: | |||
| unstructured_api_url = dify_config.UNSTRUCTURED_API_URL | |||
| unstructured_api_key = dify_config.UNSTRUCTURED_API_KEY | |||
| if etl_type == "Unstructured": | |||
| if file_extension == ".xlsx" or file_extension == ".xls": | |||
| if file_extension in {".xlsx", ".xls"}: | |||
| extractor = ExcelExtractor(file_path) | |||
| elif file_extension == ".pdf": | |||
| extractor = PdfExtractor(file_path) | |||
| elif file_extension in [".md", ".markdown"]: | |||
| elif file_extension in {".md", ".markdown"}: | |||
| extractor = ( | |||
| UnstructuredMarkdownExtractor(file_path, unstructured_api_url) | |||
| if is_automatic | |||
| else MarkdownExtractor(file_path, autodetect_encoding=True) | |||
| ) | |||
| elif file_extension in [".htm", ".html"]: | |||
| elif file_extension in {".htm", ".html"}: | |||
| extractor = HtmlExtractor(file_path) | |||
| elif file_extension == ".docx": | |||
| extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by) | |||
| @@ -134,13 +134,13 @@ class ExtractProcessor: | |||
| else TextExtractor(file_path, autodetect_encoding=True) | |||
| ) | |||
| else: | |||
| if file_extension == ".xlsx" or file_extension == ".xls": | |||
| if file_extension in {".xlsx", ".xls"}: | |||
| extractor = ExcelExtractor(file_path) | |||
| elif file_extension == ".pdf": | |||
| extractor = PdfExtractor(file_path) | |||
| elif file_extension in [".md", ".markdown"]: | |||
| elif file_extension in {".md", ".markdown"}: | |||
| extractor = MarkdownExtractor(file_path, autodetect_encoding=True) | |||
| elif file_extension in [".htm", ".html"]: | |||
| elif file_extension in {".htm", ".html"}: | |||
| extractor = HtmlExtractor(file_path) | |||
| elif file_extension == ".docx": | |||
| extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by) | |||
| @@ -32,7 +32,7 @@ class FirecrawlApp: | |||
| else: | |||
| raise Exception(f'Failed to scrape URL. Error: {response["error"]}') | |||
| elif response.status_code in [402, 409, 500]: | |||
| elif response.status_code in {402, 409, 500}: | |||
| error_message = response.json().get("error", "Unknown error occurred") | |||
| raise Exception(f"Failed to scrape URL. Status code: {response.status_code}. Error: {error_message}") | |||
| else: | |||
| @@ -103,12 +103,12 @@ class NotionExtractor(BaseExtractor): | |||
| multi_select_list = property_value[type] | |||
| for multi_select in multi_select_list: | |||
| value.append(multi_select["name"]) | |||
| elif type == "rich_text" or type == "title": | |||
| elif type in {"rich_text", "title"}: | |||
| if len(property_value[type]) > 0: | |||
| value = property_value[type][0]["plain_text"] | |||
| else: | |||
| value = "" | |||
| elif type == "select" or type == "status": | |||
| elif type in {"select", "status"}: | |||
| if property_value[type]: | |||
| value = property_value[type]["name"] | |||
| else: | |||
| @@ -115,7 +115,7 @@ class DatasetRetrieval: | |||
| available_datasets.append(dataset) | |||
| all_documents = [] | |||
| user_from = "account" if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end_user" | |||
| user_from = "account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user" | |||
| if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE: | |||
| all_documents = self.single_retrieve( | |||
| app_id, | |||
| @@ -35,7 +35,7 @@ def _split_text_with_regex(text: str, separator: str, keep_separator: bool) -> l | |||
| splits = re.split(separator, text) | |||
| else: | |||
| splits = list(text) | |||
| return [s for s in splits if (s != "" and s != "\n")] | |||
| return [s for s in splits if (s not in {"", "\n"})] | |||
| class TextSplitter(BaseDocumentTransformer, ABC): | |||
| @@ -68,7 +68,7 @@ class AppToolProviderEntity(ToolProviderController): | |||
| label = input_form[form_type]["label"] | |||
| variable_name = input_form[form_type]["variable_name"] | |||
| options = input_form[form_type].get("options", []) | |||
| if form_type == "paragraph" or form_type == "text-input": | |||
| if form_type in {"paragraph", "text-input"}: | |||
| tool["parameters"].append( | |||
| ToolParameter( | |||
| name=variable_name, | |||
| @@ -168,7 +168,7 @@ class AIPPTGenerateTool(BuiltinTool): | |||
| pass | |||
| elif event == "close": | |||
| break | |||
| elif event == "error" or event == "filter": | |||
| elif event in {"error", "filter"}: | |||
| raise Exception(f"Failed to generate outline: {data}") | |||
| return outline | |||
| @@ -213,7 +213,7 @@ class AIPPTGenerateTool(BuiltinTool): | |||
| pass | |||
| elif event == "close": | |||
| break | |||
| elif event == "error" or event == "filter": | |||
| elif event in {"error", "filter"}: | |||
| raise Exception(f"Failed to generate content: {data}") | |||
| return content | |||
| @@ -39,11 +39,11 @@ class DallE3Tool(BuiltinTool): | |||
| n = tool_parameters.get("n", 1) | |||
| # get quality | |||
| quality = tool_parameters.get("quality", "standard") | |||
| if quality not in ["standard", "hd"]: | |||
| if quality not in {"standard", "hd"}: | |||
| return self.create_text_message("Invalid quality") | |||
| # get style | |||
| style = tool_parameters.get("style", "vivid") | |||
| if style not in ["natural", "vivid"]: | |||
| if style not in {"natural", "vivid"}: | |||
| return self.create_text_message("Invalid style") | |||
| # set extra body | |||
| seed_id = tool_parameters.get("seed_id", self._generate_random_id(8)) | |||
| @@ -14,7 +14,7 @@ class SimpleCode(BuiltinTool): | |||
| language = tool_parameters.get("language", CodeLanguage.PYTHON3) | |||
| code = tool_parameters.get("code", "") | |||
| if language not in [CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT]: | |||
| if language not in {CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT}: | |||
| raise ValueError(f"Only python3 and javascript are supported, not {language}") | |||
| result = CodeExecutor.execute_code(language, "", code) | |||
| @@ -34,11 +34,11 @@ class CogView3Tool(BuiltinTool): | |||
| n = tool_parameters.get("n", 1) | |||
| # get quality | |||
| quality = tool_parameters.get("quality", "standard") | |||
| if quality not in ["standard", "hd"]: | |||
| if quality not in {"standard", "hd"}: | |||
| return self.create_text_message("Invalid quality") | |||
| # get style | |||
| style = tool_parameters.get("style", "vivid") | |||
| if style not in ["natural", "vivid"]: | |||
| if style not in {"natural", "vivid"}: | |||
| return self.create_text_message("Invalid style") | |||
| # set extra body | |||
| seed_id = tool_parameters.get("seed_id", self._generate_random_id(8)) | |||
| @@ -49,11 +49,11 @@ class DallE3Tool(BuiltinTool): | |||
| n = tool_parameters.get("n", 1) | |||
| # get quality | |||
| quality = tool_parameters.get("quality", "standard") | |||
| if quality not in ["standard", "hd"]: | |||
| if quality not in {"standard", "hd"}: | |||
| return self.create_text_message("Invalid quality") | |||
| # get style | |||
| style = tool_parameters.get("style", "vivid") | |||
| if style not in ["natural", "vivid"]: | |||
| if style not in {"natural", "vivid"}: | |||
| return self.create_text_message("Invalid style") | |||
| # call openapi dalle3 | |||
| @@ -133,9 +133,9 @@ class GetWorksheetFieldsTool(BuiltinTool): | |||
| def _extract_options(self, control: dict) -> list: | |||
| options = [] | |||
| if control["type"] in [9, 10, 11]: | |||
| if control["type"] in {9, 10, 11}: | |||
| options.extend([{"key": opt["key"], "value": opt["value"]} for opt in control.get("options", [])]) | |||
| elif control["type"] in [28, 36]: | |||
| elif control["type"] in {28, 36}: | |||
| itemnames = control["advancedSetting"].get("itemnames") | |||
| if itemnames and itemnames.startswith("[{"): | |||
| try: | |||
| @@ -183,11 +183,11 @@ class ListWorksheetRecordsTool(BuiltinTool): | |||
| type_id = field.get("typeId") | |||
| if type_id == 10: | |||
| value = value if isinstance(value, str) else "、".join(value) | |||
| elif type_id in [28, 36]: | |||
| elif type_id in {28, 36}: | |||
| value = field.get("options", {}).get(value, value) | |||
| elif type_id in [26, 27, 48, 14]: | |||
| elif type_id in {26, 27, 48, 14}: | |||
| value = self.process_value(value) | |||
| elif type_id in [35, 29]: | |||
| elif type_id in {35, 29}: | |||
| value = self.parse_cascade_or_associated(field, value) | |||
| elif type_id == 40: | |||
| value = self.parse_location(value) | |||
| @@ -35,7 +35,7 @@ class NovitaAiModelQueryTool(BuiltinTool): | |||
| models_data=[], | |||
| headers=headers, | |||
| params=params, | |||
| recursive=not (result_type == "first sd_name" or result_type == "first name sd_name pair"), | |||
| recursive=result_type not in {"first sd_name", "first name sd_name pair"}, | |||
| ) | |||
| result_str = "" | |||
| @@ -38,7 +38,7 @@ class SearchAPI: | |||
| return { | |||
| "engine": "google", | |||
| "q": query, | |||
| **{key: value for key, value in kwargs.items() if value not in [None, ""]}, | |||
| **{key: value for key, value in kwargs.items() if value not in {None, ""}}, | |||
| } | |||
| @staticmethod | |||
| @@ -38,7 +38,7 @@ class SearchAPI: | |||
| return { | |||
| "engine": "google_jobs", | |||
| "q": query, | |||
| **{key: value for key, value in kwargs.items() if value not in [None, ""]}, | |||
| **{key: value for key, value in kwargs.items() if value not in {None, ""}}, | |||
| } | |||
| @staticmethod | |||
| @@ -38,7 +38,7 @@ class SearchAPI: | |||
| return { | |||
| "engine": "google_news", | |||
| "q": query, | |||
| **{key: value for key, value in kwargs.items() if value not in [None, ""]}, | |||
| **{key: value for key, value in kwargs.items() if value not in {None, ""}}, | |||
| } | |||
| @staticmethod | |||
| @@ -38,7 +38,7 @@ class SearchAPI: | |||
| "engine": "youtube_transcripts", | |||
| "video_id": video_id, | |||
| "lang": language or "en", | |||
| **{key: value for key, value in kwargs.items() if value not in [None, ""]}, | |||
| **{key: value for key, value in kwargs.items() if value not in {None, ""}}, | |||
| } | |||
| @staticmethod | |||
| @@ -214,7 +214,7 @@ class Spider: | |||
| return requests.delete(url, headers=headers, stream=stream) | |||
| def _handle_error(self, response, action): | |||
| if response.status_code in [402, 409, 500]: | |||
| if response.status_code in {402, 409, 500}: | |||
| error_message = response.json().get("error", "Unknown error occurred") | |||
| raise Exception(f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}") | |||
| else: | |||
| @@ -32,7 +32,7 @@ class StableDiffusionTool(BuiltinTool, BaseStabilityAuthorization): | |||
| model = tool_parameters.get("model", "core") | |||
| if model in ["sd3", "sd3-turbo"]: | |||
| if model in {"sd3", "sd3-turbo"}: | |||
| payload["model"] = tool_parameters.get("model") | |||
| if model != "sd3-turbo": | |||
| @@ -38,7 +38,7 @@ class VannaTool(BuiltinTool): | |||
| vn = VannaDefault(model=model, api_key=api_key) | |||
| db_type = tool_parameters.get("db_type", "") | |||
| if db_type in ["Postgres", "MySQL", "Hive", "ClickHouse"]: | |||
| if db_type in {"Postgres", "MySQL", "Hive", "ClickHouse"}: | |||
| if not db_name: | |||
| return self.create_text_message("Please input database name") | |||
| if not username: | |||
| @@ -19,7 +19,7 @@ from core.tools.utils.yaml_utils import load_yaml_file | |||
| class BuiltinToolProviderController(ToolProviderController): | |||
| def __init__(self, **data: Any) -> None: | |||
| if self.provider_type == ToolProviderType.API or self.provider_type == ToolProviderType.APP: | |||
| if self.provider_type in {ToolProviderType.API, ToolProviderType.APP}: | |||
| super().__init__(**data) | |||
| return | |||
| @@ -153,10 +153,10 @@ class ToolProviderController(BaseModel, ABC): | |||
| # check type | |||
| credential_schema = credentials_need_to_validate[credential_name] | |||
| if ( | |||
| credential_schema == ToolProviderCredentials.CredentialsType.SECRET_INPUT | |||
| or credential_schema == ToolProviderCredentials.CredentialsType.TEXT_INPUT | |||
| ): | |||
| if credential_schema in { | |||
| ToolProviderCredentials.CredentialsType.SECRET_INPUT, | |||
| ToolProviderCredentials.CredentialsType.TEXT_INPUT, | |||
| }: | |||
| if not isinstance(credentials[credential_name], str): | |||
| raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string") | |||
| @@ -184,11 +184,11 @@ class ToolProviderController(BaseModel, ABC): | |||
| if credential_schema.default is not None: | |||
| default_value = credential_schema.default | |||
| # parse default value into the correct type | |||
| if ( | |||
| credential_schema.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT | |||
| or credential_schema.type == ToolProviderCredentials.CredentialsType.TEXT_INPUT | |||
| or credential_schema.type == ToolProviderCredentials.CredentialsType.SELECT | |||
| ): | |||
| if credential_schema.type in { | |||
| ToolProviderCredentials.CredentialsType.SECRET_INPUT, | |||
| ToolProviderCredentials.CredentialsType.TEXT_INPUT, | |||
| ToolProviderCredentials.CredentialsType.SELECT, | |||
| }: | |||
| default_value = str(default_value) | |||
| credentials[credential_name] = default_value | |||
| @@ -5,7 +5,7 @@ from urllib.parse import urlencode | |||
| import httpx | |||
| import core.helper.ssrf_proxy as ssrf_proxy | |||
| from core.helper import ssrf_proxy | |||
| from core.tools.entities.tool_bundle import ApiToolBundle | |||
| from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType | |||
| from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError | |||
| @@ -191,7 +191,7 @@ class ApiTool(Tool): | |||
| else: | |||
| body = body | |||
| if method in ("get", "head", "post", "put", "delete", "patch"): | |||
| if method in {"get", "head", "post", "put", "delete", "patch"}: | |||
| response = getattr(ssrf_proxy, method)( | |||
| url, | |||
| params=params, | |||
| @@ -224,9 +224,9 @@ class ApiTool(Tool): | |||
| elif option["type"] == "string": | |||
| return str(value) | |||
| elif option["type"] == "boolean": | |||
| if str(value).lower() in ["true", "1"]: | |||
| if str(value).lower() in {"true", "1"}: | |||
| return True | |||
| elif str(value).lower() in ["false", "0"]: | |||
| elif str(value).lower() in {"false", "0"}: | |||
| return False | |||
| else: | |||
| continue # Not a boolean, try next option | |||
| @@ -189,10 +189,7 @@ class ToolEngine: | |||
| result += response.message | |||
| elif response.type == ToolInvokeMessage.MessageType.LINK: | |||
| result += f"result link: {response.message}. please tell user to check it." | |||
| elif ( | |||
| response.type == ToolInvokeMessage.MessageType.IMAGE_LINK | |||
| or response.type == ToolInvokeMessage.MessageType.IMAGE | |||
| ): | |||
| elif response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}: | |||
| result += ( | |||
| "image has been created and sent to user already, you do not need to create it," | |||
| " just tell the user to check it now." | |||
| @@ -212,10 +209,7 @@ class ToolEngine: | |||
| result = [] | |||
| for response in tool_response: | |||
| if ( | |||
| response.type == ToolInvokeMessage.MessageType.IMAGE_LINK | |||
| or response.type == ToolInvokeMessage.MessageType.IMAGE | |||
| ): | |||
| if response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}: | |||
| mimetype = None | |||
| if response.meta.get("mime_type"): | |||
| mimetype = response.meta.get("mime_type") | |||
| @@ -297,7 +291,7 @@ class ToolEngine: | |||
| belongs_to="assistant", | |||
| url=message.url, | |||
| upload_file_id=None, | |||
| created_by_role=("account" if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end_user"), | |||
| created_by_role=("account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user"), | |||
| created_by=user_id, | |||
| ) | |||
| @@ -19,7 +19,7 @@ class ToolFileMessageTransformer: | |||
| result = [] | |||
| for message in messages: | |||
| if message.type == ToolInvokeMessage.MessageType.TEXT or message.type == ToolInvokeMessage.MessageType.LINK: | |||
| if message.type in {ToolInvokeMessage.MessageType.TEXT, ToolInvokeMessage.MessageType.LINK}: | |||
| result.append(message) | |||
| elif message.type == ToolInvokeMessage.MessageType.IMAGE: | |||
| # try to download image | |||
| @@ -165,7 +165,7 @@ class ApiBasedToolSchemaParser: | |||
| elif "schema" in parameter and "type" in parameter["schema"]: | |||
| typ = parameter["schema"]["type"] | |||
| if typ == "integer" or typ == "number": | |||
| if typ in {"integer", "number"}: | |||
| return ToolParameter.ToolParameterType.NUMBER | |||
| elif typ == "boolean": | |||
| return ToolParameter.ToolParameterType.BOOLEAN | |||
| @@ -313,7 +313,7 @@ def normalize_whitespace(text): | |||
| def is_leaf(element): | |||
| return element.name in ["p", "li"] | |||
| return element.name in {"p", "li"} | |||
| def is_text(element): | |||
| @@ -51,7 +51,7 @@ class RouteNodeState(BaseModel): | |||
| :param run_result: run result | |||
| """ | |||
| if self.status in [RouteNodeState.Status.SUCCESS, RouteNodeState.Status.FAILED]: | |||
| if self.status in {RouteNodeState.Status.SUCCESS, RouteNodeState.Status.FAILED}: | |||
| raise Exception(f"Route state {self.id} already finished") | |||
| if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: | |||
| @@ -148,11 +148,11 @@ class AnswerStreamGeneratorRouter: | |||
| for edge in reverse_edges: | |||
| source_node_id = edge.source_node_id | |||
| source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type") | |||
| if source_node_type in ( | |||
| if source_node_type in { | |||
| NodeType.ANSWER.value, | |||
| NodeType.IF_ELSE.value, | |||
| NodeType.QUESTION_CLASSIFIER.value, | |||
| ): | |||
| }: | |||
| answer_dependencies[answer_node_id].append(source_node_id) | |||
| else: | |||
| cls._recursive_fetch_answer_dependencies( | |||
| @@ -136,10 +136,10 @@ class EndStreamGeneratorRouter: | |||
| for edge in reverse_edges: | |||
| source_node_id = edge.source_node_id | |||
| source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type") | |||
| if source_node_type in ( | |||
| if source_node_type in { | |||
| NodeType.IF_ELSE.value, | |||
| NodeType.QUESTION_CLASSIFIER, | |||
| ): | |||
| }: | |||
| end_dependencies[end_node_id].append(source_node_id) | |||
| else: | |||
| cls._recursive_fetch_end_dependencies( | |||