Co-authored-by: -LAN- <laipz8200@outlook.com>
tags/0.6.12
| @@ -40,7 +40,7 @@ class AgentConfigManager: | |||
| 'provider_type': tool['provider_type'], | |||
| 'provider_id': tool['provider_id'], | |||
| 'tool_name': tool['tool_name'], | |||
| 'tool_parameters': tool['tool_parameters'] if 'tool_parameters' in tool else {} | |||
| 'tool_parameters': tool.get('tool_parameters', {}) | |||
| } | |||
| agent_tools.append(AgentToolEntity(**agent_tool_properties)) | |||
| @@ -59,7 +59,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| inputs = args['inputs'] | |||
| extras = { | |||
| "auto_generate_conversation_name": args['auto_generate_name'] if 'auto_generate_name' in args else False | |||
| "auto_generate_conversation_name": args.get('auto_generate_name', False) | |||
| } | |||
| # get conversation | |||
| @@ -57,7 +57,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): | |||
| inputs = args['inputs'] | |||
| extras = { | |||
| "auto_generate_conversation_name": args['auto_generate_name'] if 'auto_generate_name' in args else True | |||
| "auto_generate_conversation_name": args.get('auto_generate_name', True) | |||
| } | |||
| # get conversation | |||
| @@ -203,7 +203,7 @@ class AgentChatAppRunner(AppRunner): | |||
| llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) | |||
| model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials) | |||
| if set([ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL]).intersection(model_schema.features or []): | |||
| if {ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL}.intersection(model_schema.features or []): | |||
| agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING | |||
| conversation = db.session.query(Conversation).filter(Conversation.id == conversation.id).first() | |||
| @@ -55,7 +55,7 @@ class ChatAppGenerator(MessageBasedAppGenerator): | |||
| inputs = args['inputs'] | |||
| extras = { | |||
| "auto_generate_conversation_name": args['auto_generate_name'] if 'auto_generate_name' in args else True | |||
| "auto_generate_conversation_name": args.get('auto_generate_name', True) | |||
| } | |||
| # get conversation | |||
| @@ -66,8 +66,8 @@ class ProviderConfiguration(BaseModel): | |||
| original_provider_configurate_methods[self.provider.provider].append(configurate_method) | |||
| if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]: | |||
| if (any([len(quota_configuration.restrict_models) > 0 | |||
| for quota_configuration in self.system_configuration.quota_configurations]) | |||
| if (any(len(quota_configuration.restrict_models) > 0 | |||
| for quota_configuration in self.system_configuration.quota_configurations) | |||
| and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods): | |||
| self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL) | |||
| @@ -397,7 +397,7 @@ class IndexingRunner: | |||
| document_id=dataset_document.id, | |||
| after_indexing_status="splitting", | |||
| extra_update_params={ | |||
| DatasetDocument.word_count: sum([len(text_doc.page_content) for text_doc in text_docs]), | |||
| DatasetDocument.word_count: sum(len(text_doc.page_content) for text_doc in text_docs), | |||
| DatasetDocument.parsing_completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) | |||
| } | |||
| ) | |||
| @@ -83,7 +83,7 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel): | |||
| max_workers = self._get_model_workers_limit(model, credentials) | |||
| try: | |||
| sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit)) | |||
| audio_bytes_list = list() | |||
| audio_bytes_list = [] | |||
| # Create a thread pool and map the function to the list of sentences | |||
| with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: | |||
| @@ -175,8 +175,8 @@ class BedrockLargeLanguageModel(LargeLanguageModel): | |||
| # - https://docs.anthropic.com/claude/reference/claude-on-amazon-bedrock | |||
| # - https://github.com/anthropics/anthropic-sdk-python | |||
| client = AnthropicBedrock( | |||
| aws_access_key=credentials.get("aws_access_key_id", None), | |||
| aws_secret_key=credentials.get("aws_secret_access_key", None), | |||
| aws_access_key=credentials.get("aws_access_key_id"), | |||
| aws_secret_key=credentials.get("aws_secret_access_key"), | |||
| aws_region=credentials["aws_region"], | |||
| ) | |||
| @@ -576,7 +576,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel): | |||
| """ | |||
| Create payload for bedrock api call depending on model provider | |||
| """ | |||
| payload = dict() | |||
| payload = {} | |||
| model_prefix = model.split('.')[0] | |||
| model_name = model.split('.')[1] | |||
| @@ -648,8 +648,8 @@ class BedrockLargeLanguageModel(LargeLanguageModel): | |||
| runtime_client = boto3.client( | |||
| service_name='bedrock-runtime', | |||
| config=client_config, | |||
| aws_access_key_id=credentials.get("aws_access_key_id", None), | |||
| aws_secret_access_key=credentials.get("aws_secret_access_key", None) | |||
| aws_access_key_id=credentials.get("aws_access_key_id"), | |||
| aws_secret_access_key=credentials.get("aws_secret_access_key") | |||
| ) | |||
| model_prefix = model.split('.')[0] | |||
| @@ -49,8 +49,8 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel): | |||
| bedrock_runtime = boto3.client( | |||
| service_name='bedrock-runtime', | |||
| config=client_config, | |||
| aws_access_key_id=credentials.get("aws_access_key_id", None), | |||
| aws_secret_access_key=credentials.get("aws_secret_access_key", None) | |||
| aws_access_key_id=credentials.get("aws_access_key_id"), | |||
| aws_secret_access_key=credentials.get("aws_secret_access_key") | |||
| ) | |||
| embeddings = [] | |||
| @@ -148,7 +148,7 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel): | |||
| """ | |||
| Create payload for bedrock api call depending on model provider | |||
| """ | |||
| payload = dict() | |||
| payload = {} | |||
| if model_prefix == "amazon": | |||
| payload['inputText'] = texts | |||
| @@ -696,12 +696,10 @@ class CohereLargeLanguageModel(LargeLanguageModel): | |||
| en_US=model | |||
| ), | |||
| model_type=ModelType.LLM, | |||
| features=[feature for feature in base_model_schema_features], | |||
| features=list(base_model_schema_features), | |||
| fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, | |||
| model_properties={ | |||
| key: property for key, property in base_model_schema_model_properties.items() | |||
| }, | |||
| parameter_rules=[rule for rule in base_model_schema_parameters_rules], | |||
| model_properties=dict(base_model_schema_model_properties.items()), | |||
| parameter_rules=list(base_model_schema_parameters_rules), | |||
| pricing=base_model_schema.pricing | |||
| ) | |||
| @@ -277,10 +277,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel): | |||
| type='function', | |||
| function=AssistantPromptMessage.ToolCall.ToolCallFunction( | |||
| name=part.function_call.name, | |||
| arguments=json.dumps({ | |||
| key: value | |||
| for key, value in part.function_call.args.items() | |||
| }) | |||
| arguments=json.dumps(dict(part.function_call.args.items())) | |||
| ) | |||
| ) | |||
| ] | |||
| @@ -88,9 +88,9 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel): | |||
| def _add_function_call(self, model: str, credentials: dict) -> None: | |||
| model_schema = self.get_model_schema(model, credentials) | |||
| if model_schema and set([ | |||
| if model_schema and { | |||
| ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL | |||
| ]).intersection(model_schema.features or []): | |||
| }.intersection(model_schema.features or []): | |||
| credentials['function_calling_type'] = 'tool_call' | |||
| def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: | |||
| @@ -100,10 +100,10 @@ class NVIDIALargeLanguageModel(OAIAPICompatLargeLanguageModel): | |||
| if api_key: | |||
| headers["Authorization"] = f"Bearer {api_key}" | |||
| endpoint_url = credentials['endpoint_url'] if 'endpoint_url' in credentials else None | |||
| endpoint_url = credentials.get('endpoint_url') | |||
| if endpoint_url and not endpoint_url.endswith('/'): | |||
| endpoint_url += '/' | |||
| server_url = credentials['server_url'] if 'server_url' in credentials else None | |||
| server_url = credentials.get('server_url') | |||
| # prepare the payload for a simple ping to the model | |||
| data = { | |||
| @@ -182,10 +182,10 @@ class NVIDIALargeLanguageModel(OAIAPICompatLargeLanguageModel): | |||
| if stream: | |||
| headers['Accept'] = 'text/event-stream' | |||
| endpoint_url = credentials['endpoint_url'] if 'endpoint_url' in credentials else None | |||
| endpoint_url = credentials.get('endpoint_url') | |||
| if endpoint_url and not endpoint_url.endswith('/'): | |||
| endpoint_url += '/' | |||
| server_url = credentials['server_url'] if 'server_url' in credentials else None | |||
| server_url = credentials.get('server_url') | |||
| data = { | |||
| "model": model, | |||
| @@ -1073,12 +1073,10 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): | |||
| en_US=model | |||
| ), | |||
| model_type=ModelType.LLM, | |||
| features=[feature for feature in base_model_schema_features], | |||
| features=list(base_model_schema_features), | |||
| fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, | |||
| model_properties={ | |||
| key: property for key, property in base_model_schema_model_properties.items() | |||
| }, | |||
| parameter_rules=[rule for rule in base_model_schema_parameters_rules], | |||
| model_properties=dict(base_model_schema_model_properties.items()), | |||
| parameter_rules=list(base_model_schema_parameters_rules), | |||
| pricing=base_model_schema.pricing | |||
| ) | |||
| @@ -80,7 +80,7 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel): | |||
| max_workers = self._get_model_workers_limit(model, credentials) | |||
| try: | |||
| sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit)) | |||
| audio_bytes_list = list() | |||
| audio_bytes_list = [] | |||
| # Create a thread pool and map the function to the list of sentences | |||
| with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: | |||
| @@ -275,14 +275,13 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel): | |||
| @classmethod | |||
| def _get_parameter_type(cls, param_type: str) -> str: | |||
| if param_type == 'integer': | |||
| return 'int' | |||
| elif param_type == 'number': | |||
| return 'float' | |||
| elif param_type == 'boolean': | |||
| return 'boolean' | |||
| elif param_type == 'string': | |||
| return 'string' | |||
| type_mapping = { | |||
| 'integer': 'int', | |||
| 'number': 'float', | |||
| 'boolean': 'boolean', | |||
| 'string': 'string' | |||
| } | |||
| return type_mapping.get(param_type) | |||
| def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str: | |||
| messages = messages.copy() # don't mutate the original list | |||
| @@ -80,7 +80,7 @@ class TongyiText2SpeechModel(_CommonTongyi, TTSModel): | |||
| max_workers = self._get_model_workers_limit(model, credentials) | |||
| try: | |||
| sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit)) | |||
| audio_bytes_list = list() | |||
| audio_bytes_list = [] | |||
| # Create a thread pool and map the function to the list of sentences | |||
| with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: | |||
| @@ -579,10 +579,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): | |||
| type='function', | |||
| function=AssistantPromptMessage.ToolCall.ToolCallFunction( | |||
| name=part.function_call.name, | |||
| arguments=json.dumps({ | |||
| key: value | |||
| for key, value in part.function_call.args.items() | |||
| }) | |||
| arguments=json.dumps(dict(part.function_call.args.items())) | |||
| ) | |||
| ) | |||
| ] | |||
| @@ -102,7 +102,7 @@ class Signer: | |||
| body_hash = Util.sha256(request.body) | |||
| request.headers['X-Content-Sha256'] = body_hash | |||
| signed_headers = dict() | |||
| signed_headers = {} | |||
| for key in request.headers: | |||
| if key in ['Content-Type', 'Content-Md5', 'Host'] or key.startswith('X-'): | |||
| signed_headers[key.lower()] = request.headers[key] | |||
| @@ -150,7 +150,7 @@ class Request: | |||
| self.headers = OrderedDict() | |||
| self.query = OrderedDict() | |||
| self.body = '' | |||
| self.form = dict() | |||
| self.form = {} | |||
| self.connection_timeout = 0 | |||
| self.socket_timeout = 0 | |||
| @@ -147,7 +147,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): | |||
| return self._get_num_tokens_by_gpt2(text) | |||
| if is_completion_model: | |||
| return sum([tokens(str(message.content)) for message in messages]) | |||
| return sum(tokens(str(message.content)) for message in messages) | |||
| tokens_per_message = 3 | |||
| tokens_per_name = 1 | |||
| @@ -18,7 +18,7 @@ class _CommonZhipuaiAI: | |||
| """ | |||
| credentials_kwargs = { | |||
| "api_key": credentials['api_key'] if 'api_key' in credentials else | |||
| credentials['zhipuai_api_key'] if 'zhipuai_api_key' in credentials else None, | |||
| credentials.get("zhipuai_api_key"), | |||
| } | |||
| return credentials_kwargs | |||
| @@ -148,7 +148,7 @@ class SimplePromptTransform(PromptTransform): | |||
| special_variable_keys.append('#histories#') | |||
| if query_in_prompt: | |||
| prompt += prompt_rules['query_prompt'] if 'query_prompt' in prompt_rules else '{{#query#}}' | |||
| prompt += prompt_rules.get('query_prompt', '{{#query#}}') | |||
| special_variable_keys.append('#query#') | |||
| return { | |||
| @@ -234,8 +234,8 @@ class SimplePromptTransform(PromptTransform): | |||
| ) | |||
| ), | |||
| max_token_limit=rest_tokens, | |||
| human_prefix=prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human', | |||
| ai_prefix=prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant' | |||
| human_prefix=prompt_rules.get('human_prefix', 'Human'), | |||
| ai_prefix=prompt_rules.get('assistant_prefix', 'Assistant') | |||
| ) | |||
| # get prompt | |||
| @@ -417,7 +417,7 @@ class ProviderManager: | |||
| model_load_balancing_enabled = cache_result == 'True' | |||
| if not model_load_balancing_enabled: | |||
| return dict() | |||
| return {} | |||
| provider_load_balancing_configs = db.session.query(LoadBalancingModelConfig) \ | |||
| .filter( | |||
| @@ -451,7 +451,7 @@ class ProviderManager: | |||
| if not provider_records: | |||
| provider_records = [] | |||
| provider_quota_to_provider_record_dict = dict() | |||
| provider_quota_to_provider_record_dict = {} | |||
| for provider_record in provider_records: | |||
| if provider_record.provider_type != ProviderType.SYSTEM.value: | |||
| continue | |||
| @@ -661,7 +661,7 @@ class ProviderManager: | |||
| provider_hosting_configuration = hosting_configuration.provider_map.get(provider_entity.provider) | |||
| # Convert provider_records to dict | |||
| quota_type_to_provider_records_dict = dict() | |||
| quota_type_to_provider_records_dict = {} | |||
| for provider_record in provider_records: | |||
| if provider_record.provider_type != ProviderType.SYSTEM.value: | |||
| continue | |||
| @@ -197,7 +197,7 @@ class Jieba(BaseKeyword): | |||
| chunk_indices_count[node_id] += 1 | |||
| sorted_chunk_indices = sorted( | |||
| list(chunk_indices_count.keys()), | |||
| chunk_indices_count.keys(), | |||
| key=lambda x: chunk_indices_count[x], | |||
| reverse=True, | |||
| ) | |||
| @@ -201,7 +201,7 @@ class ReactMultiDatasetRouter: | |||
| tool_strings.append( | |||
| f"{tool.name}: {tool.description}, args: {{'query': {{'title': 'Query', 'description': 'Query for the dataset to be used to retrieve the dataset.', 'type': 'string'}}}}") | |||
| formatted_tools = "\n".join(tool_strings) | |||
| unique_tool_names = set(tool.name for tool in tools) | |||
| unique_tool_names = {tool.name for tool in tools} | |||
| tool_names = ", ".join('"' + name + '"' for name in unique_tool_names) | |||
| format_instructions = format_instructions.format(tool_names=tool_names) | |||
| template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix]) | |||
| @@ -105,15 +105,15 @@ class BingSearchTool(BuiltinTool): | |||
| def validate_credentials(self, credentials: dict[str, Any], tool_parameters: dict[str, Any]) -> None: | |||
| key = credentials.get('subscription_key', None) | |||
| key = credentials.get('subscription_key') | |||
| if not key: | |||
| raise Exception('subscription_key is required') | |||
| server_url = credentials.get('server_url', None) | |||
| server_url = credentials.get('server_url') | |||
| if not server_url: | |||
| server_url = self.url | |||
| query = tool_parameters.get('query', None) | |||
| query = tool_parameters.get('query') | |||
| if not query: | |||
| raise Exception('query is required') | |||
| @@ -170,7 +170,7 @@ class BingSearchTool(BuiltinTool): | |||
| if not server_url: | |||
| server_url = self.url | |||
| query = tool_parameters.get('query', None) | |||
| query = tool_parameters.get('query') | |||
| if not query: | |||
| raise Exception('query is required') | |||
| @@ -16,12 +16,12 @@ class BarChartTool(BuiltinTool): | |||
| data = data.split(';') | |||
| # if all data is int, convert to int | |||
| if all([i.isdigit() for i in data]): | |||
| if all(i.isdigit() for i in data): | |||
| data = [int(i) for i in data] | |||
| else: | |||
| data = [float(i) for i in data] | |||
| axis = tool_parameters.get('x_axis', None) or None | |||
| axis = tool_parameters.get('x_axis') or None | |||
| if axis: | |||
| axis = axis.split(';') | |||
| if len(axis) != len(data): | |||
| @@ -17,14 +17,14 @@ class LinearChartTool(BuiltinTool): | |||
| return self.create_text_message('Please input data') | |||
| data = data.split(';') | |||
| axis = tool_parameters.get('x_axis', None) or None | |||
| axis = tool_parameters.get('x_axis') or None | |||
| if axis: | |||
| axis = axis.split(';') | |||
| if len(axis) != len(data): | |||
| axis = None | |||
| # if all data is int, convert to int | |||
| if all([i.isdigit() for i in data]): | |||
| if all(i.isdigit() for i in data): | |||
| data = [int(i) for i in data] | |||
| else: | |||
| data = [float(i) for i in data] | |||
| @@ -16,10 +16,10 @@ class PieChartTool(BuiltinTool): | |||
| if not data: | |||
| return self.create_text_message('Please input data') | |||
| data = data.split(';') | |||
| categories = tool_parameters.get('categories', None) or None | |||
| categories = tool_parameters.get('categories') or None | |||
| # if all data is int, convert to int | |||
| if all([i.isdigit() for i in data]): | |||
| if all(i.isdigit() for i in data): | |||
| data = [int(i) for i in data] | |||
| else: | |||
| data = [float(i) for i in data] | |||
| @@ -37,10 +37,10 @@ class GaodeRepositoriesTool(BuiltinTool): | |||
| apikey=self.runtime.credentials.get('api_key'))) | |||
| weatherInfo_data = weatherInfo_response.json() | |||
| if weatherInfo_response.status_code == 200 and weatherInfo_data.get('info') == 'OK': | |||
| contents = list() | |||
| contents = [] | |||
| if len(weatherInfo_data.get('forecasts')) > 0: | |||
| for item in weatherInfo_data['forecasts'][0]['casts']: | |||
| content = dict() | |||
| content = {} | |||
| content['date'] = item.get('date') | |||
| content['week'] = item.get('week') | |||
| content['dayweather'] = item.get('dayweather') | |||
| @@ -39,10 +39,10 @@ class GihubRepositoriesTool(BuiltinTool): | |||
| f"q={quote(query)}&sort=stars&per_page={top_n}&order=desc") | |||
| response_data = response.json() | |||
| if response.status_code == 200 and isinstance(response_data.get('items'), list): | |||
| contents = list() | |||
| contents = [] | |||
| if len(response_data.get('items')) > 0: | |||
| for item in response_data.get('items'): | |||
| content = dict() | |||
| content = {} | |||
| updated_at_object = datetime.strptime(item['updated_at'], "%Y-%m-%dT%H:%M:%SZ") | |||
| content['owner'] = item['owner']['login'] | |||
| content['name'] = item['name'] | |||
| @@ -26,11 +26,11 @@ class JinaReaderTool(BuiltinTool): | |||
| if 'api_key' in self.runtime.credentials and self.runtime.credentials.get('api_key'): | |||
| headers['Authorization'] = "Bearer " + self.runtime.credentials.get('api_key') | |||
| target_selector = tool_parameters.get('target_selector', None) | |||
| target_selector = tool_parameters.get('target_selector') | |||
| if target_selector is not None and target_selector != '': | |||
| headers['X-Target-Selector'] = target_selector | |||
| wait_for_selector = tool_parameters.get('wait_for_selector', None) | |||
| wait_for_selector = tool_parameters.get('wait_for_selector') | |||
| if wait_for_selector is not None and wait_for_selector != '': | |||
| headers['X-Wait-For-Selector'] = wait_for_selector | |||
| @@ -43,7 +43,7 @@ class JinaReaderTool(BuiltinTool): | |||
| if tool_parameters.get('gather_all_images_at_the_end', False): | |||
| headers['X-With-Images-Summary'] = 'true' | |||
| proxy_server = tool_parameters.get('proxy_server', None) | |||
| proxy_server = tool_parameters.get('proxy_server') | |||
| if proxy_server is not None and proxy_server != '': | |||
| headers['X-Proxy-Url'] = proxy_server | |||
| @@ -33,7 +33,7 @@ class JinaSearchTool(BuiltinTool): | |||
| if tool_parameters.get('gather_all_images_at_the_end', False): | |||
| headers['X-With-Images-Summary'] = 'true' | |||
| proxy_server = tool_parameters.get('proxy_server', None) | |||
| proxy_server = tool_parameters.get('proxy_server') | |||
| if proxy_server is not None and proxy_server != '': | |||
| headers['X-Proxy-Url'] = proxy_server | |||
| @@ -94,7 +94,7 @@ class GoogleTool(BuiltinTool): | |||
| google_domain = tool_parameters.get("google_domain", "google.com") | |||
| gl = tool_parameters.get("gl", "us") | |||
| hl = tool_parameters.get("hl", "en") | |||
| location = tool_parameters.get("location", None) | |||
| location = tool_parameters.get("location") | |||
| api_key = self.runtime.credentials['searchapi_api_key'] | |||
| result = SearchAPI(api_key).run(query, result_type=result_type, num=num, google_domain=google_domain, gl=gl, hl=hl, location=location) | |||
| @@ -72,11 +72,11 @@ class GoogleJobsTool(BuiltinTool): | |||
| """ | |||
| query = tool_parameters['query'] | |||
| result_type = tool_parameters['result_type'] | |||
| is_remote = tool_parameters.get("is_remote", None) | |||
| is_remote = tool_parameters.get("is_remote") | |||
| google_domain = tool_parameters.get("google_domain", "google.com") | |||
| gl = tool_parameters.get("gl", "us") | |||
| hl = tool_parameters.get("hl", "en") | |||
| location = tool_parameters.get("location", None) | |||
| location = tool_parameters.get("location") | |||
| ltype = 1 if is_remote else None | |||
| @@ -82,7 +82,7 @@ class GoogleNewsTool(BuiltinTool): | |||
| google_domain = tool_parameters.get("google_domain", "google.com") | |||
| gl = tool_parameters.get("gl", "us") | |||
| hl = tool_parameters.get("hl", "en") | |||
| location = tool_parameters.get("location", None) | |||
| location = tool_parameters.get("location") | |||
| api_key = self.runtime.credentials['searchapi_api_key'] | |||
| result = SearchAPI(api_key).run(query, result_type=result_type, num=num, google_domain=google_domain, gl=gl, hl=hl, location=location) | |||
| @@ -107,7 +107,7 @@ class SearXNGSearchTool(BuiltinTool): | |||
| if not host: | |||
| raise Exception('SearXNG api is required') | |||
| query = tool_parameters.get('query', None) | |||
| query = tool_parameters.get('query') | |||
| if not query: | |||
| return self.create_text_message('Please input query') | |||
| @@ -43,7 +43,7 @@ class GetMarkdownTool(BuiltinTool): | |||
| Invoke the SerplyApi tool. | |||
| """ | |||
| url = tool_parameters["url"] | |||
| location = tool_parameters.get("location", None) | |||
| location = tool_parameters.get("location") | |||
| api_key = self.runtime.credentials["serply_api_key"] | |||
| result = SerplyApi(api_key).run(url, location=location) | |||
| @@ -55,7 +55,7 @@ class SerplyApi: | |||
| f"Employer: {job['employer']}", | |||
| f"Location: {job['location']}", | |||
| f"Link: {job['link']}", | |||
| f"""Highest: {", ".join([h for h in job["highlights"]])}""", | |||
| f"""Highest: {", ".join(list(job["highlights"]))}""", | |||
| "---", | |||
| ]) | |||
| ) | |||
| @@ -78,7 +78,7 @@ class JobSearchTool(BuiltinTool): | |||
| query = tool_parameters["query"] | |||
| gl = tool_parameters.get("gl", "us") | |||
| hl = tool_parameters.get("hl", "en") | |||
| location = tool_parameters.get("location", None) | |||
| location = tool_parameters.get("location") | |||
| api_key = self.runtime.credentials["serply_api_key"] | |||
| result = SerplyApi(api_key).run(query, gl=gl, hl=hl, location=location) | |||
| @@ -80,7 +80,7 @@ class NewsSearchTool(BuiltinTool): | |||
| query = tool_parameters["query"] | |||
| gl = tool_parameters.get("gl", "us") | |||
| hl = tool_parameters.get("hl", "en") | |||
| location = tool_parameters.get("location", None) | |||
| location = tool_parameters.get("location") | |||
| api_key = self.runtime.credentials["serply_api_key"] | |||
| result = SerplyApi(api_key).run(query, gl=gl, hl=hl, location=location) | |||
| @@ -83,7 +83,7 @@ class ScholarSearchTool(BuiltinTool): | |||
| query = tool_parameters["query"] | |||
| gl = tool_parameters.get("gl", "us") | |||
| hl = tool_parameters.get("hl", "en") | |||
| location = tool_parameters.get("location", None) | |||
| location = tool_parameters.get("location") | |||
| api_key = self.runtime.credentials["serply_api_key"] | |||
| result = SerplyApi(api_key).run(query, gl=gl, hl=hl, location=location) | |||
| @@ -38,7 +38,7 @@ class BuiltinToolProviderController(ToolProviderController): | |||
| super().__init__(**{ | |||
| 'identity': provider_yaml['identity'], | |||
| 'credentials_schema': provider_yaml['credentials_for_provider'] if 'credentials_for_provider' in provider_yaml else None, | |||
| 'credentials_schema': provider_yaml.get('credentials_for_provider', None), | |||
| }) | |||
| def _get_builtin_tools(self) -> list[Tool]: | |||
| @@ -159,8 +159,8 @@ class ApiTool(Tool): | |||
| for content_type in self.api_bundle.openapi['requestBody']['content']: | |||
| headers['Content-Type'] = content_type | |||
| body_schema = self.api_bundle.openapi['requestBody']['content'][content_type]['schema'] | |||
| required = body_schema['required'] if 'required' in body_schema else [] | |||
| properties = body_schema['properties'] if 'properties' in body_schema else {} | |||
| required = body_schema.get('required', []) | |||
| properties = body_schema.get('properties', {}) | |||
| for name, property in properties.items(): | |||
| if name in parameters: | |||
| # convert type | |||
| @@ -90,7 +90,7 @@ class DatasetRetrieverTool(Tool): | |||
| """ | |||
| invoke dataset retriever tool | |||
| """ | |||
| query = tool_parameters.get('query', None) | |||
| query = tool_parameters.get('query') | |||
| if not query: | |||
| return self.create_text_message(text='please input query') | |||
| @@ -209,7 +209,7 @@ class ToolManager: | |||
| if parameter_rule.type == ToolParameter.ToolParameterType.SELECT: | |||
| # check if tool_parameter_config in options | |||
| options = list(map(lambda x: x.value, parameter_rule.options)) | |||
| options = [x.value for x in parameter_rule.options] | |||
| if parameter_value is not None and parameter_value not in options: | |||
| raise ValueError( | |||
| f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}") | |||
| @@ -21,10 +21,7 @@ class ApiBasedToolSchemaParser: | |||
| extra_info = extra_info if extra_info is not None else {} | |||
| # set description to extra_info | |||
| if 'description' in openapi['info']: | |||
| extra_info['description'] = openapi['info']['description'] | |||
| else: | |||
| extra_info['description'] = '' | |||
| extra_info['description'] = openapi['info'].get('description', '') | |||
| if len(openapi['servers']) == 0: | |||
| raise ToolProviderNotFoundError('No server found in the openapi yaml.') | |||
| @@ -95,8 +92,8 @@ class ApiBasedToolSchemaParser: | |||
| # parse body parameters | |||
| if 'schema' in interface['operation']['requestBody']['content'][content_type]: | |||
| body_schema = interface['operation']['requestBody']['content'][content_type]['schema'] | |||
| required = body_schema['required'] if 'required' in body_schema else [] | |||
| properties = body_schema['properties'] if 'properties' in body_schema else {} | |||
| required = body_schema.get('required', []) | |||
| properties = body_schema.get('properties', {}) | |||
| for name, property in properties.items(): | |||
| tool = ToolParameter( | |||
| name=name, | |||
| @@ -105,14 +102,14 @@ class ApiBasedToolSchemaParser: | |||
| zh_Hans=name | |||
| ), | |||
| human_description=I18nObject( | |||
| en_US=property['description'] if 'description' in property else '', | |||
| zh_Hans=property['description'] if 'description' in property else '' | |||
| en_US=property.get('description', ''), | |||
| zh_Hans=property.get('description', '') | |||
| ), | |||
| type=ToolParameter.ToolParameterType.STRING, | |||
| required=name in required, | |||
| form=ToolParameter.ToolParameterForm.LLM, | |||
| llm_description=property['description'] if 'description' in property else '', | |||
| default=property['default'] if 'default' in property else None, | |||
| llm_description=property.get('description', ''), | |||
| default=property.get('default', None), | |||
| ) | |||
| # check if there is a type | |||
| @@ -149,7 +146,7 @@ class ApiBasedToolSchemaParser: | |||
| server_url=server_url + interface['path'], | |||
| method=interface['method'], | |||
| summary=interface['operation']['description'] if 'description' in interface['operation'] else | |||
| interface['operation']['summary'] if 'summary' in interface['operation'] else None, | |||
| interface['operation'].get('summary', None), | |||
| operation_id=interface['operation']['operationId'], | |||
| parameters=parameters, | |||
| author='', | |||
| @@ -283,7 +283,7 @@ def strip_control_characters(text): | |||
| # [Cn]: Other, Not Assigned | |||
| # [Co]: Other, Private Use | |||
| # [Cs]: Other, Surrogate | |||
| control_chars = set(['Cc', 'Cf', 'Cn', 'Co', 'Cs']) | |||
| control_chars = {'Cc', 'Cf', 'Cn', 'Co', 'Cs'} | |||
| retained_chars = ['\t', '\n', '\r', '\f'] | |||
| # Remove non-printing control characters | |||
| @@ -93,7 +93,7 @@ class ParameterExtractorNode(LLMNode): | |||
| # fetch memory | |||
| memory = self._fetch_memory(node_data.memory, variable_pool, model_instance) | |||
| if set(model_schema.features or []) & set([ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL]) \ | |||
| if set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL} \ | |||
| and node_data.reasoning_mode == 'function_call': | |||
| # use function call | |||
| prompt_messages, prompt_message_tools = self._generate_function_call_prompt( | |||
| @@ -644,7 +644,7 @@ class ParameterExtractorNode(LLMNode): | |||
| if not model_schema: | |||
| raise ValueError("Model schema not found") | |||
| if set(model_schema.features or []) & set([ModelFeature.MULTI_TOOL_CALL, ModelFeature.MULTI_TOOL_CALL]): | |||
| if set(model_schema.features or []) & {ModelFeature.MULTI_TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}: | |||
| prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, None, 2000) | |||
| else: | |||
| prompt_template = self._get_prompt_engineering_prompt_template(node_data, query, variable_pool, None, 2000) | |||
| @@ -246,10 +246,7 @@ class NotionOAuth(OAuthDataSource): | |||
| } | |||
| response = requests.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers) | |||
| response_json = response.json() | |||
| if 'results' in response_json: | |||
| results = response_json['results'] | |||
| else: | |||
| results = [] | |||
| results = response_json.get('results', []) | |||
| return results | |||
| def notion_block_parent_page_id(self, access_token: str, block_id: str): | |||
| @@ -293,8 +290,5 @@ class NotionOAuth(OAuthDataSource): | |||
| } | |||
| response = requests.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers) | |||
| response_json = response.json() | |||
| if 'results' in response_json: | |||
| results = response_json['results'] | |||
| else: | |||
| results = [] | |||
| results = response_json.get('results', []) | |||
| return results | |||
| @@ -14,9 +14,11 @@ line-length = 120 | |||
| preview = true | |||
| select = [ | |||
| "B", # flake8-bugbear rules | |||
| "C4", # flake8-comprehensions | |||
| "F", # pyflakes rules | |||
| "I", # isort rules | |||
| "UP", # pyupgrade rules | |||
| "UP", # pyupgrade rules | |||
| "B035", # static-key-dict-comprehension | |||
| "E101", # mixed-spaces-and-tabs | |||
| "E111", # indentation-with-invalid-multiple | |||
| "E112", # no-indented-block | |||
| @@ -28,8 +30,13 @@ select = [ | |||
| "RUF100", # unused-noqa | |||
| "RUF101", # redirected-noqa | |||
| "S506", # unsafe-yaml-load | |||
| "SIM116", # if-else-block-instead-of-dict-lookup | |||
| "SIM401", # if-else-block-instead-of-dict-get | |||
| "SIM910", # dict-get-with-none-default | |||
| "W191", # tab-indentation | |||
| "W605", # invalid-escape-sequence | |||
| "F601", # multi-value-repeated-key-literal | |||
| "F602", # multi-value-repeated-key-variable | |||
| ] | |||
| ignore = [ | |||
| "F403", # undefined-local-with-import-star | |||
| @@ -82,8 +89,8 @@ HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL = "b" | |||
| HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL = "c" | |||
| MOCK_SWITCH = "true" | |||
| CODE_MAX_STRING_LENGTH = "80000" | |||
| CODE_EXECUTION_ENDPOINT="http://127.0.0.1:8194" | |||
| CODE_EXECUTION_API_KEY="dify-sandbox" | |||
| CODE_EXECUTION_ENDPOINT = "http://127.0.0.1:8194" | |||
| CODE_EXECUTION_API_KEY = "dify-sandbox" | |||
| FIRECRAWL_API_KEY = "fc-" | |||
| [tool.poetry] | |||
| @@ -114,11 +121,11 @@ cachetools = "~5.3.0" | |||
| weaviate-client = "~3.21.0" | |||
| mailchimp-transactional = "~1.0.50" | |||
| scikit-learn = "1.2.2" | |||
| sentry-sdk = {version = "~1.39.2", extras = ["flask"]} | |||
| sentry-sdk = { version = "~1.39.2", extras = ["flask"] } | |||
| sympy = "1.12" | |||
| jieba = "0.42.1" | |||
| celery = "~5.3.6" | |||
| redis = {version = "~5.0.3", extras = ["hiredis"]} | |||
| redis = { version = "~5.0.3", extras = ["hiredis"] } | |||
| chardet = "~5.1.0" | |||
| python-docx = "~1.1.0" | |||
| pypdfium2 = "~4.17.0" | |||
| @@ -138,7 +145,7 @@ googleapis-common-protos = "1.63.0" | |||
| google-cloud-storage = "2.16.0" | |||
| replicate = "~0.22.0" | |||
| websocket-client = "~1.7.0" | |||
| dashscope = {version = "~1.17.0", extras = ["tokenizer"]} | |||
| dashscope = { version = "~1.17.0", extras = ["tokenizer"] } | |||
| huggingface-hub = "~0.16.4" | |||
| transformers = "~4.35.0" | |||
| tokenizers = "~0.15.0" | |||
| @@ -152,10 +159,10 @@ qdrant-client = "1.7.3" | |||
| cohere = "~5.2.4" | |||
| pyyaml = "~6.0.1" | |||
| numpy = "~1.26.4" | |||
| unstructured = {version = "~0.10.27", extras = ["docx", "epub", "md", "msg", "ppt", "pptx"]} | |||
| unstructured = { version = "~0.10.27", extras = ["docx", "epub", "md", "msg", "ppt", "pptx"] } | |||
| bs4 = "~0.0.1" | |||
| markdown = "~3.5.1" | |||
| httpx = {version = "~0.27.0", extras = ["socks"]} | |||
| httpx = { version = "~0.27.0", extras = ["socks"] } | |||
| matplotlib = "~3.8.2" | |||
| yfinance = "~0.2.40" | |||
| pydub = "~0.25.1" | |||
| @@ -180,7 +187,7 @@ pgvector = "0.2.5" | |||
| pymysql = "1.1.1" | |||
| tidb-vector = "0.0.9" | |||
| google-cloud-aiplatform = "1.49.0" | |||
| vanna = {version = "0.5.5", extras = ["postgres", "mysql", "clickhouse", "duckdb"]} | |||
| vanna = { version = "0.5.5", extras = ["postgres", "mysql", "clickhouse", "duckdb"] } | |||
| kaleido = "0.2.1" | |||
| tencentcloud-sdk-python-hunyuan = "~3.0.1158" | |||
| tcvectordb = "1.3.2" | |||
| @@ -696,7 +696,7 @@ class DocumentService: | |||
| elif document_data["data_source"]["type"] == "notion_import": | |||
| notion_info_list = document_data["data_source"]['info_list']['notion_info_list'] | |||
| exist_page_ids = [] | |||
| exist_document = dict() | |||
| exist_document = {} | |||
| documents = Document.query.filter_by( | |||
| dataset_id=dataset.id, | |||
| tenant_id=current_user.current_tenant_id, | |||
| @@ -95,7 +95,7 @@ class RecommendedAppService: | |||
| categories.add(recommended_app.category) # add category to categories | |||
| return {'recommended_apps': recommended_apps_result, 'categories': sorted(list(categories))} | |||
| return {'recommended_apps': recommended_apps_result, 'categories': sorted(categories)} | |||
| @classmethod | |||
| def _fetch_recommended_apps_from_dify_official(cls, language: str) -> dict: | |||
| @@ -514,8 +514,8 @@ class WorkflowConverter: | |||
| prompt_rules = prompt_template_config['prompt_rules'] | |||
| role_prefix = { | |||
| "user": prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human', | |||
| "assistant": prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant' | |||
| "user": prompt_rules.get('human_prefix', 'Human'), | |||
| "assistant": prompt_rules.get('assistant_prefix', 'Assistant') | |||
| } | |||
| else: | |||
| advanced_completion_prompt_template = prompt_template.advanced_completion_prompt_template | |||
| @@ -112,7 +112,7 @@ def test_execute_llm(setup_openai_mock): | |||
| # Mock db.session.close() | |||
| db.session.close = MagicMock() | |||
| node._fetch_model_config = MagicMock(return_value=tuple([model_instance, model_config])) | |||
| node._fetch_model_config = MagicMock(return_value=(model_instance, model_config)) | |||
| # execute node | |||
| result = node.run(pool) | |||
| @@ -229,7 +229,7 @@ def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock): | |||
| # Mock db.session.close() | |||
| db.session.close = MagicMock() | |||
| node._fetch_model_config = MagicMock(return_value=tuple([model_instance, model_config])) | |||
| node._fetch_model_config = MagicMock(return_value=(model_instance, model_config)) | |||
| # execute node | |||
| result = node.run(pool) | |||
| @@ -59,7 +59,7 @@ def get_mocked_fetch_model_config( | |||
| provider_model_bundle=provider_model_bundle | |||
| ) | |||
| return MagicMock(return_value=tuple([model_instance, model_config])) | |||
| return MagicMock(return_value=(model_instance, model_config)) | |||
| @pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) | |||
| def test_function_calling_parameter_extractor(setup_openai_mock): | |||
| @@ -238,8 +238,8 @@ def test__get_completion_model_prompt_messages(): | |||
| prompt_rules = prompt_template['prompt_rules'] | |||
| full_inputs = {**inputs, '#context#': context, '#query#': query, '#histories#': memory.get_history_prompt_text( | |||
| max_token_limit=2000, | |||
| human_prefix=prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human', | |||
| ai_prefix=prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant' | |||
| human_prefix=prompt_rules.get("human_prefix", "Human"), | |||
| ai_prefix=prompt_rules.get("assistant_prefix", "Assistant") | |||
| )} | |||
| real_prompt = prompt_template['prompt_template'].format(full_inputs) | |||