Co-authored-by: -LAN- <laipz8200@outlook.com>
tags/0.6.12
| 'provider_type': tool['provider_type'], | 'provider_type': tool['provider_type'], | ||||
| 'provider_id': tool['provider_id'], | 'provider_id': tool['provider_id'], | ||||
| 'tool_name': tool['tool_name'], | 'tool_name': tool['tool_name'], | ||||
| 'tool_parameters': tool['tool_parameters'] if 'tool_parameters' in tool else {} | |||||
| 'tool_parameters': tool.get('tool_parameters', {}) | |||||
| } | } | ||||
| agent_tools.append(AgentToolEntity(**agent_tool_properties)) | agent_tools.append(AgentToolEntity(**agent_tool_properties)) |
| inputs = args['inputs'] | inputs = args['inputs'] | ||||
| extras = { | extras = { | ||||
| "auto_generate_conversation_name": args['auto_generate_name'] if 'auto_generate_name' in args else False | |||||
| "auto_generate_conversation_name": args.get('auto_generate_name', False) | |||||
| } | } | ||||
| # get conversation | # get conversation |
| inputs = args['inputs'] | inputs = args['inputs'] | ||||
| extras = { | extras = { | ||||
| "auto_generate_conversation_name": args['auto_generate_name'] if 'auto_generate_name' in args else True | |||||
| "auto_generate_conversation_name": args.get('auto_generate_name', True) | |||||
| } | } | ||||
| # get conversation | # get conversation |
| llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) | llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) | ||||
| model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials) | model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials) | ||||
| if set([ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL]).intersection(model_schema.features or []): | |||||
| if {ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL}.intersection(model_schema.features or []): | |||||
| agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING | agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING | ||||
| conversation = db.session.query(Conversation).filter(Conversation.id == conversation.id).first() | conversation = db.session.query(Conversation).filter(Conversation.id == conversation.id).first() |
| inputs = args['inputs'] | inputs = args['inputs'] | ||||
| extras = { | extras = { | ||||
| "auto_generate_conversation_name": args['auto_generate_name'] if 'auto_generate_name' in args else True | |||||
| "auto_generate_conversation_name": args.get('auto_generate_name', True) | |||||
| } | } | ||||
| # get conversation | # get conversation |
| original_provider_configurate_methods[self.provider.provider].append(configurate_method) | original_provider_configurate_methods[self.provider.provider].append(configurate_method) | ||||
| if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]: | if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]: | ||||
| if (any([len(quota_configuration.restrict_models) > 0 | |||||
| for quota_configuration in self.system_configuration.quota_configurations]) | |||||
| if (any(len(quota_configuration.restrict_models) > 0 | |||||
| for quota_configuration in self.system_configuration.quota_configurations) | |||||
| and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods): | and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods): | ||||
| self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL) | self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL) | ||||
| document_id=dataset_document.id, | document_id=dataset_document.id, | ||||
| after_indexing_status="splitting", | after_indexing_status="splitting", | ||||
| extra_update_params={ | extra_update_params={ | ||||
| DatasetDocument.word_count: sum([len(text_doc.page_content) for text_doc in text_docs]), | |||||
| DatasetDocument.word_count: sum(len(text_doc.page_content) for text_doc in text_docs), | |||||
| DatasetDocument.parsing_completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) | DatasetDocument.parsing_completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) | ||||
| } | } | ||||
| ) | ) |
| max_workers = self._get_model_workers_limit(model, credentials) | max_workers = self._get_model_workers_limit(model, credentials) | ||||
| try: | try: | ||||
| sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit)) | sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit)) | ||||
| audio_bytes_list = list() | |||||
| audio_bytes_list = [] | |||||
| # Create a thread pool and map the function to the list of sentences | # Create a thread pool and map the function to the list of sentences | ||||
| with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: | with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: |
| # - https://docs.anthropic.com/claude/reference/claude-on-amazon-bedrock | # - https://docs.anthropic.com/claude/reference/claude-on-amazon-bedrock | ||||
| # - https://github.com/anthropics/anthropic-sdk-python | # - https://github.com/anthropics/anthropic-sdk-python | ||||
| client = AnthropicBedrock( | client = AnthropicBedrock( | ||||
| aws_access_key=credentials.get("aws_access_key_id", None), | |||||
| aws_secret_key=credentials.get("aws_secret_access_key", None), | |||||
| aws_access_key=credentials.get("aws_access_key_id"), | |||||
| aws_secret_key=credentials.get("aws_secret_access_key"), | |||||
| aws_region=credentials["aws_region"], | aws_region=credentials["aws_region"], | ||||
| ) | ) | ||||
| """ | """ | ||||
| Create payload for bedrock api call depending on model provider | Create payload for bedrock api call depending on model provider | ||||
| """ | """ | ||||
| payload = dict() | |||||
| payload = {} | |||||
| model_prefix = model.split('.')[0] | model_prefix = model.split('.')[0] | ||||
| model_name = model.split('.')[1] | model_name = model.split('.')[1] | ||||
| runtime_client = boto3.client( | runtime_client = boto3.client( | ||||
| service_name='bedrock-runtime', | service_name='bedrock-runtime', | ||||
| config=client_config, | config=client_config, | ||||
| aws_access_key_id=credentials.get("aws_access_key_id", None), | |||||
| aws_secret_access_key=credentials.get("aws_secret_access_key", None) | |||||
| aws_access_key_id=credentials.get("aws_access_key_id"), | |||||
| aws_secret_access_key=credentials.get("aws_secret_access_key") | |||||
| ) | ) | ||||
| model_prefix = model.split('.')[0] | model_prefix = model.split('.')[0] |
| bedrock_runtime = boto3.client( | bedrock_runtime = boto3.client( | ||||
| service_name='bedrock-runtime', | service_name='bedrock-runtime', | ||||
| config=client_config, | config=client_config, | ||||
| aws_access_key_id=credentials.get("aws_access_key_id", None), | |||||
| aws_secret_access_key=credentials.get("aws_secret_access_key", None) | |||||
| aws_access_key_id=credentials.get("aws_access_key_id"), | |||||
| aws_secret_access_key=credentials.get("aws_secret_access_key") | |||||
| ) | ) | ||||
| embeddings = [] | embeddings = [] | ||||
| """ | """ | ||||
| Create payload for bedrock api call depending on model provider | Create payload for bedrock api call depending on model provider | ||||
| """ | """ | ||||
| payload = dict() | |||||
| payload = {} | |||||
| if model_prefix == "amazon": | if model_prefix == "amazon": | ||||
| payload['inputText'] = texts | payload['inputText'] = texts |
| en_US=model | en_US=model | ||||
| ), | ), | ||||
| model_type=ModelType.LLM, | model_type=ModelType.LLM, | ||||
| features=[feature for feature in base_model_schema_features], | |||||
| features=list(base_model_schema_features), | |||||
| fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, | fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, | ||||
| model_properties={ | |||||
| key: property for key, property in base_model_schema_model_properties.items() | |||||
| }, | |||||
| parameter_rules=[rule for rule in base_model_schema_parameters_rules], | |||||
| model_properties=dict(base_model_schema_model_properties.items()), | |||||
| parameter_rules=list(base_model_schema_parameters_rules), | |||||
| pricing=base_model_schema.pricing | pricing=base_model_schema.pricing | ||||
| ) | ) | ||||
| type='function', | type='function', | ||||
| function=AssistantPromptMessage.ToolCall.ToolCallFunction( | function=AssistantPromptMessage.ToolCall.ToolCallFunction( | ||||
| name=part.function_call.name, | name=part.function_call.name, | ||||
| arguments=json.dumps({ | |||||
| key: value | |||||
| for key, value in part.function_call.args.items() | |||||
| }) | |||||
| arguments=json.dumps(dict(part.function_call.args.items())) | |||||
| ) | ) | ||||
| ) | ) | ||||
| ] | ] |
| def _add_function_call(self, model: str, credentials: dict) -> None: | def _add_function_call(self, model: str, credentials: dict) -> None: | ||||
| model_schema = self.get_model_schema(model, credentials) | model_schema = self.get_model_schema(model, credentials) | ||||
| if model_schema and set([ | |||||
| if model_schema and { | |||||
| ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL | ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL | ||||
| ]).intersection(model_schema.features or []): | |||||
| }.intersection(model_schema.features or []): | |||||
| credentials['function_calling_type'] = 'tool_call' | credentials['function_calling_type'] = 'tool_call' | ||||
| def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: | def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: |
| if api_key: | if api_key: | ||||
| headers["Authorization"] = f"Bearer {api_key}" | headers["Authorization"] = f"Bearer {api_key}" | ||||
| endpoint_url = credentials['endpoint_url'] if 'endpoint_url' in credentials else None | |||||
| endpoint_url = credentials.get('endpoint_url') | |||||
| if endpoint_url and not endpoint_url.endswith('/'): | if endpoint_url and not endpoint_url.endswith('/'): | ||||
| endpoint_url += '/' | endpoint_url += '/' | ||||
| server_url = credentials['server_url'] if 'server_url' in credentials else None | |||||
| server_url = credentials.get('server_url') | |||||
| # prepare the payload for a simple ping to the model | # prepare the payload for a simple ping to the model | ||||
| data = { | data = { | ||||
| if stream: | if stream: | ||||
| headers['Accept'] = 'text/event-stream' | headers['Accept'] = 'text/event-stream' | ||||
| endpoint_url = credentials['endpoint_url'] if 'endpoint_url' in credentials else None | |||||
| endpoint_url = credentials.get('endpoint_url') | |||||
| if endpoint_url and not endpoint_url.endswith('/'): | if endpoint_url and not endpoint_url.endswith('/'): | ||||
| endpoint_url += '/' | endpoint_url += '/' | ||||
| server_url = credentials['server_url'] if 'server_url' in credentials else None | |||||
| server_url = credentials.get('server_url') | |||||
| data = { | data = { | ||||
| "model": model, | "model": model, |
| en_US=model | en_US=model | ||||
| ), | ), | ||||
| model_type=ModelType.LLM, | model_type=ModelType.LLM, | ||||
| features=[feature for feature in base_model_schema_features], | |||||
| features=list(base_model_schema_features), | |||||
| fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, | fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, | ||||
| model_properties={ | |||||
| key: property for key, property in base_model_schema_model_properties.items() | |||||
| }, | |||||
| parameter_rules=[rule for rule in base_model_schema_parameters_rules], | |||||
| model_properties=dict(base_model_schema_model_properties.items()), | |||||
| parameter_rules=list(base_model_schema_parameters_rules), | |||||
| pricing=base_model_schema.pricing | pricing=base_model_schema.pricing | ||||
| ) | ) | ||||
| max_workers = self._get_model_workers_limit(model, credentials) | max_workers = self._get_model_workers_limit(model, credentials) | ||||
| try: | try: | ||||
| sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit)) | sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit)) | ||||
| audio_bytes_list = list() | |||||
| audio_bytes_list = [] | |||||
| # Create a thread pool and map the function to the list of sentences | # Create a thread pool and map the function to the list of sentences | ||||
| with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: | with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: |
| @classmethod | @classmethod | ||||
| def _get_parameter_type(cls, param_type: str) -> str: | def _get_parameter_type(cls, param_type: str) -> str: | ||||
| if param_type == 'integer': | |||||
| return 'int' | |||||
| elif param_type == 'number': | |||||
| return 'float' | |||||
| elif param_type == 'boolean': | |||||
| return 'boolean' | |||||
| elif param_type == 'string': | |||||
| return 'string' | |||||
| type_mapping = { | |||||
| 'integer': 'int', | |||||
| 'number': 'float', | |||||
| 'boolean': 'boolean', | |||||
| 'string': 'string' | |||||
| } | |||||
| return type_mapping.get(param_type) | |||||
| def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str: | def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str: | ||||
| messages = messages.copy() # don't mutate the original list | messages = messages.copy() # don't mutate the original list |
| max_workers = self._get_model_workers_limit(model, credentials) | max_workers = self._get_model_workers_limit(model, credentials) | ||||
| try: | try: | ||||
| sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit)) | sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit)) | ||||
| audio_bytes_list = list() | |||||
| audio_bytes_list = [] | |||||
| # Create a thread pool and map the function to the list of sentences | # Create a thread pool and map the function to the list of sentences | ||||
| with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: | with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: |
| type='function', | type='function', | ||||
| function=AssistantPromptMessage.ToolCall.ToolCallFunction( | function=AssistantPromptMessage.ToolCall.ToolCallFunction( | ||||
| name=part.function_call.name, | name=part.function_call.name, | ||||
| arguments=json.dumps({ | |||||
| key: value | |||||
| for key, value in part.function_call.args.items() | |||||
| }) | |||||
| arguments=json.dumps(dict(part.function_call.args.items())) | |||||
| ) | ) | ||||
| ) | ) | ||||
| ] | ] |
| body_hash = Util.sha256(request.body) | body_hash = Util.sha256(request.body) | ||||
| request.headers['X-Content-Sha256'] = body_hash | request.headers['X-Content-Sha256'] = body_hash | ||||
| signed_headers = dict() | |||||
| signed_headers = {} | |||||
| for key in request.headers: | for key in request.headers: | ||||
| if key in ['Content-Type', 'Content-Md5', 'Host'] or key.startswith('X-'): | if key in ['Content-Type', 'Content-Md5', 'Host'] or key.startswith('X-'): | ||||
| signed_headers[key.lower()] = request.headers[key] | signed_headers[key.lower()] = request.headers[key] |
| self.headers = OrderedDict() | self.headers = OrderedDict() | ||||
| self.query = OrderedDict() | self.query = OrderedDict() | ||||
| self.body = '' | self.body = '' | ||||
| self.form = dict() | |||||
| self.form = {} | |||||
| self.connection_timeout = 0 | self.connection_timeout = 0 | ||||
| self.socket_timeout = 0 | self.socket_timeout = 0 | ||||
| return self._get_num_tokens_by_gpt2(text) | return self._get_num_tokens_by_gpt2(text) | ||||
| if is_completion_model: | if is_completion_model: | ||||
| return sum([tokens(str(message.content)) for message in messages]) | |||||
| return sum(tokens(str(message.content)) for message in messages) | |||||
| tokens_per_message = 3 | tokens_per_message = 3 | ||||
| tokens_per_name = 1 | tokens_per_name = 1 |
| """ | """ | ||||
| credentials_kwargs = { | credentials_kwargs = { | ||||
| "api_key": credentials['api_key'] if 'api_key' in credentials else | "api_key": credentials['api_key'] if 'api_key' in credentials else | ||||
| credentials['zhipuai_api_key'] if 'zhipuai_api_key' in credentials else None, | |||||
| credentials.get("zhipuai_api_key"), | |||||
| } | } | ||||
| return credentials_kwargs | return credentials_kwargs |
| special_variable_keys.append('#histories#') | special_variable_keys.append('#histories#') | ||||
| if query_in_prompt: | if query_in_prompt: | ||||
| prompt += prompt_rules['query_prompt'] if 'query_prompt' in prompt_rules else '{{#query#}}' | |||||
| prompt += prompt_rules.get('query_prompt', '{{#query#}}') | |||||
| special_variable_keys.append('#query#') | special_variable_keys.append('#query#') | ||||
| return { | return { | ||||
| ) | ) | ||||
| ), | ), | ||||
| max_token_limit=rest_tokens, | max_token_limit=rest_tokens, | ||||
| human_prefix=prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human', | |||||
| ai_prefix=prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant' | |||||
| human_prefix=prompt_rules.get('human_prefix', 'Human'), | |||||
| ai_prefix=prompt_rules.get('assistant_prefix', 'Assistant') | |||||
| ) | ) | ||||
| # get prompt | # get prompt |
| model_load_balancing_enabled = cache_result == 'True' | model_load_balancing_enabled = cache_result == 'True' | ||||
| if not model_load_balancing_enabled: | if not model_load_balancing_enabled: | ||||
| return dict() | |||||
| return {} | |||||
| provider_load_balancing_configs = db.session.query(LoadBalancingModelConfig) \ | provider_load_balancing_configs = db.session.query(LoadBalancingModelConfig) \ | ||||
| .filter( | .filter( | ||||
| if not provider_records: | if not provider_records: | ||||
| provider_records = [] | provider_records = [] | ||||
| provider_quota_to_provider_record_dict = dict() | |||||
| provider_quota_to_provider_record_dict = {} | |||||
| for provider_record in provider_records: | for provider_record in provider_records: | ||||
| if provider_record.provider_type != ProviderType.SYSTEM.value: | if provider_record.provider_type != ProviderType.SYSTEM.value: | ||||
| continue | continue | ||||
| provider_hosting_configuration = hosting_configuration.provider_map.get(provider_entity.provider) | provider_hosting_configuration = hosting_configuration.provider_map.get(provider_entity.provider) | ||||
| # Convert provider_records to dict | # Convert provider_records to dict | ||||
| quota_type_to_provider_records_dict = dict() | |||||
| quota_type_to_provider_records_dict = {} | |||||
| for provider_record in provider_records: | for provider_record in provider_records: | ||||
| if provider_record.provider_type != ProviderType.SYSTEM.value: | if provider_record.provider_type != ProviderType.SYSTEM.value: | ||||
| continue | continue |
| chunk_indices_count[node_id] += 1 | chunk_indices_count[node_id] += 1 | ||||
| sorted_chunk_indices = sorted( | sorted_chunk_indices = sorted( | ||||
| list(chunk_indices_count.keys()), | |||||
| chunk_indices_count.keys(), | |||||
| key=lambda x: chunk_indices_count[x], | key=lambda x: chunk_indices_count[x], | ||||
| reverse=True, | reverse=True, | ||||
| ) | ) |
| tool_strings.append( | tool_strings.append( | ||||
| f"{tool.name}: {tool.description}, args: {{'query': {{'title': 'Query', 'description': 'Query for the dataset to be used to retrieve the dataset.', 'type': 'string'}}}}") | f"{tool.name}: {tool.description}, args: {{'query': {{'title': 'Query', 'description': 'Query for the dataset to be used to retrieve the dataset.', 'type': 'string'}}}}") | ||||
| formatted_tools = "\n".join(tool_strings) | formatted_tools = "\n".join(tool_strings) | ||||
| unique_tool_names = set(tool.name for tool in tools) | |||||
| unique_tool_names = {tool.name for tool in tools} | |||||
| tool_names = ", ".join('"' + name + '"' for name in unique_tool_names) | tool_names = ", ".join('"' + name + '"' for name in unique_tool_names) | ||||
| format_instructions = format_instructions.format(tool_names=tool_names) | format_instructions = format_instructions.format(tool_names=tool_names) | ||||
| template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix]) | template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix]) |
| def validate_credentials(self, credentials: dict[str, Any], tool_parameters: dict[str, Any]) -> None: | def validate_credentials(self, credentials: dict[str, Any], tool_parameters: dict[str, Any]) -> None: | ||||
| key = credentials.get('subscription_key', None) | |||||
| key = credentials.get('subscription_key') | |||||
| if not key: | if not key: | ||||
| raise Exception('subscription_key is required') | raise Exception('subscription_key is required') | ||||
| server_url = credentials.get('server_url', None) | |||||
| server_url = credentials.get('server_url') | |||||
| if not server_url: | if not server_url: | ||||
| server_url = self.url | server_url = self.url | ||||
| query = tool_parameters.get('query', None) | |||||
| query = tool_parameters.get('query') | |||||
| if not query: | if not query: | ||||
| raise Exception('query is required') | raise Exception('query is required') | ||||
| if not server_url: | if not server_url: | ||||
| server_url = self.url | server_url = self.url | ||||
| query = tool_parameters.get('query', None) | |||||
| query = tool_parameters.get('query') | |||||
| if not query: | if not query: | ||||
| raise Exception('query is required') | raise Exception('query is required') | ||||
| data = data.split(';') | data = data.split(';') | ||||
| # if all data is int, convert to int | # if all data is int, convert to int | ||||
| if all([i.isdigit() for i in data]): | |||||
| if all(i.isdigit() for i in data): | |||||
| data = [int(i) for i in data] | data = [int(i) for i in data] | ||||
| else: | else: | ||||
| data = [float(i) for i in data] | data = [float(i) for i in data] | ||||
| axis = tool_parameters.get('x_axis', None) or None | |||||
| axis = tool_parameters.get('x_axis') or None | |||||
| if axis: | if axis: | ||||
| axis = axis.split(';') | axis = axis.split(';') | ||||
| if len(axis) != len(data): | if len(axis) != len(data): |
| return self.create_text_message('Please input data') | return self.create_text_message('Please input data') | ||||
| data = data.split(';') | data = data.split(';') | ||||
| axis = tool_parameters.get('x_axis', None) or None | |||||
| axis = tool_parameters.get('x_axis') or None | |||||
| if axis: | if axis: | ||||
| axis = axis.split(';') | axis = axis.split(';') | ||||
| if len(axis) != len(data): | if len(axis) != len(data): | ||||
| axis = None | axis = None | ||||
| # if all data is int, convert to int | # if all data is int, convert to int | ||||
| if all([i.isdigit() for i in data]): | |||||
| if all(i.isdigit() for i in data): | |||||
| data = [int(i) for i in data] | data = [int(i) for i in data] | ||||
| else: | else: | ||||
| data = [float(i) for i in data] | data = [float(i) for i in data] |
| if not data: | if not data: | ||||
| return self.create_text_message('Please input data') | return self.create_text_message('Please input data') | ||||
| data = data.split(';') | data = data.split(';') | ||||
| categories = tool_parameters.get('categories', None) or None | |||||
| categories = tool_parameters.get('categories') or None | |||||
| # if all data is int, convert to int | # if all data is int, convert to int | ||||
| if all([i.isdigit() for i in data]): | |||||
| if all(i.isdigit() for i in data): | |||||
| data = [int(i) for i in data] | data = [int(i) for i in data] | ||||
| else: | else: | ||||
| data = [float(i) for i in data] | data = [float(i) for i in data] |
| apikey=self.runtime.credentials.get('api_key'))) | apikey=self.runtime.credentials.get('api_key'))) | ||||
| weatherInfo_data = weatherInfo_response.json() | weatherInfo_data = weatherInfo_response.json() | ||||
| if weatherInfo_response.status_code == 200 and weatherInfo_data.get('info') == 'OK': | if weatherInfo_response.status_code == 200 and weatherInfo_data.get('info') == 'OK': | ||||
| contents = list() | |||||
| contents = [] | |||||
| if len(weatherInfo_data.get('forecasts')) > 0: | if len(weatherInfo_data.get('forecasts')) > 0: | ||||
| for item in weatherInfo_data['forecasts'][0]['casts']: | for item in weatherInfo_data['forecasts'][0]['casts']: | ||||
| content = dict() | |||||
| content = {} | |||||
| content['date'] = item.get('date') | content['date'] = item.get('date') | ||||
| content['week'] = item.get('week') | content['week'] = item.get('week') | ||||
| content['dayweather'] = item.get('dayweather') | content['dayweather'] = item.get('dayweather') |
| f"q={quote(query)}&sort=stars&per_page={top_n}&order=desc") | f"q={quote(query)}&sort=stars&per_page={top_n}&order=desc") | ||||
| response_data = response.json() | response_data = response.json() | ||||
| if response.status_code == 200 and isinstance(response_data.get('items'), list): | if response.status_code == 200 and isinstance(response_data.get('items'), list): | ||||
| contents = list() | |||||
| contents = [] | |||||
| if len(response_data.get('items')) > 0: | if len(response_data.get('items')) > 0: | ||||
| for item in response_data.get('items'): | for item in response_data.get('items'): | ||||
| content = dict() | |||||
| content = {} | |||||
| updated_at_object = datetime.strptime(item['updated_at'], "%Y-%m-%dT%H:%M:%SZ") | updated_at_object = datetime.strptime(item['updated_at'], "%Y-%m-%dT%H:%M:%SZ") | ||||
| content['owner'] = item['owner']['login'] | content['owner'] = item['owner']['login'] | ||||
| content['name'] = item['name'] | content['name'] = item['name'] |
| if 'api_key' in self.runtime.credentials and self.runtime.credentials.get('api_key'): | if 'api_key' in self.runtime.credentials and self.runtime.credentials.get('api_key'): | ||||
| headers['Authorization'] = "Bearer " + self.runtime.credentials.get('api_key') | headers['Authorization'] = "Bearer " + self.runtime.credentials.get('api_key') | ||||
| target_selector = tool_parameters.get('target_selector', None) | |||||
| target_selector = tool_parameters.get('target_selector') | |||||
| if target_selector is not None and target_selector != '': | if target_selector is not None and target_selector != '': | ||||
| headers['X-Target-Selector'] = target_selector | headers['X-Target-Selector'] = target_selector | ||||
| wait_for_selector = tool_parameters.get('wait_for_selector', None) | |||||
| wait_for_selector = tool_parameters.get('wait_for_selector') | |||||
| if wait_for_selector is not None and wait_for_selector != '': | if wait_for_selector is not None and wait_for_selector != '': | ||||
| headers['X-Wait-For-Selector'] = wait_for_selector | headers['X-Wait-For-Selector'] = wait_for_selector | ||||
| if tool_parameters.get('gather_all_images_at_the_end', False): | if tool_parameters.get('gather_all_images_at_the_end', False): | ||||
| headers['X-With-Images-Summary'] = 'true' | headers['X-With-Images-Summary'] = 'true' | ||||
| proxy_server = tool_parameters.get('proxy_server', None) | |||||
| proxy_server = tool_parameters.get('proxy_server') | |||||
| if proxy_server is not None and proxy_server != '': | if proxy_server is not None and proxy_server != '': | ||||
| headers['X-Proxy-Url'] = proxy_server | headers['X-Proxy-Url'] = proxy_server | ||||
| if tool_parameters.get('gather_all_images_at_the_end', False): | if tool_parameters.get('gather_all_images_at_the_end', False): | ||||
| headers['X-With-Images-Summary'] = 'true' | headers['X-With-Images-Summary'] = 'true' | ||||
| proxy_server = tool_parameters.get('proxy_server', None) | |||||
| proxy_server = tool_parameters.get('proxy_server') | |||||
| if proxy_server is not None and proxy_server != '': | if proxy_server is not None and proxy_server != '': | ||||
| headers['X-Proxy-Url'] = proxy_server | headers['X-Proxy-Url'] = proxy_server | ||||
| google_domain = tool_parameters.get("google_domain", "google.com") | google_domain = tool_parameters.get("google_domain", "google.com") | ||||
| gl = tool_parameters.get("gl", "us") | gl = tool_parameters.get("gl", "us") | ||||
| hl = tool_parameters.get("hl", "en") | hl = tool_parameters.get("hl", "en") | ||||
| location = tool_parameters.get("location", None) | |||||
| location = tool_parameters.get("location") | |||||
| api_key = self.runtime.credentials['searchapi_api_key'] | api_key = self.runtime.credentials['searchapi_api_key'] | ||||
| result = SearchAPI(api_key).run(query, result_type=result_type, num=num, google_domain=google_domain, gl=gl, hl=hl, location=location) | result = SearchAPI(api_key).run(query, result_type=result_type, num=num, google_domain=google_domain, gl=gl, hl=hl, location=location) |
| """ | """ | ||||
| query = tool_parameters['query'] | query = tool_parameters['query'] | ||||
| result_type = tool_parameters['result_type'] | result_type = tool_parameters['result_type'] | ||||
| is_remote = tool_parameters.get("is_remote", None) | |||||
| is_remote = tool_parameters.get("is_remote") | |||||
| google_domain = tool_parameters.get("google_domain", "google.com") | google_domain = tool_parameters.get("google_domain", "google.com") | ||||
| gl = tool_parameters.get("gl", "us") | gl = tool_parameters.get("gl", "us") | ||||
| hl = tool_parameters.get("hl", "en") | hl = tool_parameters.get("hl", "en") | ||||
| location = tool_parameters.get("location", None) | |||||
| location = tool_parameters.get("location") | |||||
| ltype = 1 if is_remote else None | ltype = 1 if is_remote else None | ||||
| google_domain = tool_parameters.get("google_domain", "google.com") | google_domain = tool_parameters.get("google_domain", "google.com") | ||||
| gl = tool_parameters.get("gl", "us") | gl = tool_parameters.get("gl", "us") | ||||
| hl = tool_parameters.get("hl", "en") | hl = tool_parameters.get("hl", "en") | ||||
| location = tool_parameters.get("location", None) | |||||
| location = tool_parameters.get("location") | |||||
| api_key = self.runtime.credentials['searchapi_api_key'] | api_key = self.runtime.credentials['searchapi_api_key'] | ||||
| result = SearchAPI(api_key).run(query, result_type=result_type, num=num, google_domain=google_domain, gl=gl, hl=hl, location=location) | result = SearchAPI(api_key).run(query, result_type=result_type, num=num, google_domain=google_domain, gl=gl, hl=hl, location=location) |
| if not host: | if not host: | ||||
| raise Exception('SearXNG api is required') | raise Exception('SearXNG api is required') | ||||
| query = tool_parameters.get('query', None) | |||||
| query = tool_parameters.get('query') | |||||
| if not query: | if not query: | ||||
| return self.create_text_message('Please input query') | return self.create_text_message('Please input query') | ||||
| Invoke the SerplyApi tool. | Invoke the SerplyApi tool. | ||||
| """ | """ | ||||
| url = tool_parameters["url"] | url = tool_parameters["url"] | ||||
| location = tool_parameters.get("location", None) | |||||
| location = tool_parameters.get("location") | |||||
| api_key = self.runtime.credentials["serply_api_key"] | api_key = self.runtime.credentials["serply_api_key"] | ||||
| result = SerplyApi(api_key).run(url, location=location) | result = SerplyApi(api_key).run(url, location=location) |
| f"Employer: {job['employer']}", | f"Employer: {job['employer']}", | ||||
| f"Location: {job['location']}", | f"Location: {job['location']}", | ||||
| f"Link: {job['link']}", | f"Link: {job['link']}", | ||||
| f"""Highest: {", ".join([h for h in job["highlights"]])}""", | |||||
| f"""Highest: {", ".join(list(job["highlights"]))}""", | |||||
| "---", | "---", | ||||
| ]) | ]) | ||||
| ) | ) | ||||
| query = tool_parameters["query"] | query = tool_parameters["query"] | ||||
| gl = tool_parameters.get("gl", "us") | gl = tool_parameters.get("gl", "us") | ||||
| hl = tool_parameters.get("hl", "en") | hl = tool_parameters.get("hl", "en") | ||||
| location = tool_parameters.get("location", None) | |||||
| location = tool_parameters.get("location") | |||||
| api_key = self.runtime.credentials["serply_api_key"] | api_key = self.runtime.credentials["serply_api_key"] | ||||
| result = SerplyApi(api_key).run(query, gl=gl, hl=hl, location=location) | result = SerplyApi(api_key).run(query, gl=gl, hl=hl, location=location) |
| query = tool_parameters["query"] | query = tool_parameters["query"] | ||||
| gl = tool_parameters.get("gl", "us") | gl = tool_parameters.get("gl", "us") | ||||
| hl = tool_parameters.get("hl", "en") | hl = tool_parameters.get("hl", "en") | ||||
| location = tool_parameters.get("location", None) | |||||
| location = tool_parameters.get("location") | |||||
| api_key = self.runtime.credentials["serply_api_key"] | api_key = self.runtime.credentials["serply_api_key"] | ||||
| result = SerplyApi(api_key).run(query, gl=gl, hl=hl, location=location) | result = SerplyApi(api_key).run(query, gl=gl, hl=hl, location=location) |
| query = tool_parameters["query"] | query = tool_parameters["query"] | ||||
| gl = tool_parameters.get("gl", "us") | gl = tool_parameters.get("gl", "us") | ||||
| hl = tool_parameters.get("hl", "en") | hl = tool_parameters.get("hl", "en") | ||||
| location = tool_parameters.get("location", None) | |||||
| location = tool_parameters.get("location") | |||||
| api_key = self.runtime.credentials["serply_api_key"] | api_key = self.runtime.credentials["serply_api_key"] | ||||
| result = SerplyApi(api_key).run(query, gl=gl, hl=hl, location=location) | result = SerplyApi(api_key).run(query, gl=gl, hl=hl, location=location) |
| super().__init__(**{ | super().__init__(**{ | ||||
| 'identity': provider_yaml['identity'], | 'identity': provider_yaml['identity'], | ||||
| 'credentials_schema': provider_yaml['credentials_for_provider'] if 'credentials_for_provider' in provider_yaml else None, | |||||
| 'credentials_schema': provider_yaml.get('credentials_for_provider', None), | |||||
| }) | }) | ||||
| def _get_builtin_tools(self) -> list[Tool]: | def _get_builtin_tools(self) -> list[Tool]: |
| for content_type in self.api_bundle.openapi['requestBody']['content']: | for content_type in self.api_bundle.openapi['requestBody']['content']: | ||||
| headers['Content-Type'] = content_type | headers['Content-Type'] = content_type | ||||
| body_schema = self.api_bundle.openapi['requestBody']['content'][content_type]['schema'] | body_schema = self.api_bundle.openapi['requestBody']['content'][content_type]['schema'] | ||||
| required = body_schema['required'] if 'required' in body_schema else [] | |||||
| properties = body_schema['properties'] if 'properties' in body_schema else {} | |||||
| required = body_schema.get('required', []) | |||||
| properties = body_schema.get('properties', {}) | |||||
| for name, property in properties.items(): | for name, property in properties.items(): | ||||
| if name in parameters: | if name in parameters: | ||||
| # convert type | # convert type |
| """ | """ | ||||
| invoke dataset retriever tool | invoke dataset retriever tool | ||||
| """ | """ | ||||
| query = tool_parameters.get('query', None) | |||||
| query = tool_parameters.get('query') | |||||
| if not query: | if not query: | ||||
| return self.create_text_message(text='please input query') | return self.create_text_message(text='please input query') | ||||
| if parameter_rule.type == ToolParameter.ToolParameterType.SELECT: | if parameter_rule.type == ToolParameter.ToolParameterType.SELECT: | ||||
| # check if tool_parameter_config in options | # check if tool_parameter_config in options | ||||
| options = list(map(lambda x: x.value, parameter_rule.options)) | |||||
| options = [x.value for x in parameter_rule.options] | |||||
| if parameter_value is not None and parameter_value not in options: | if parameter_value is not None and parameter_value not in options: | ||||
| raise ValueError( | raise ValueError( | ||||
| f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}") | f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}") |
| extra_info = extra_info if extra_info is not None else {} | extra_info = extra_info if extra_info is not None else {} | ||||
| # set description to extra_info | # set description to extra_info | ||||
| if 'description' in openapi['info']: | |||||
| extra_info['description'] = openapi['info']['description'] | |||||
| else: | |||||
| extra_info['description'] = '' | |||||
| extra_info['description'] = openapi['info'].get('description', '') | |||||
| if len(openapi['servers']) == 0: | if len(openapi['servers']) == 0: | ||||
| raise ToolProviderNotFoundError('No server found in the openapi yaml.') | raise ToolProviderNotFoundError('No server found in the openapi yaml.') | ||||
| # parse body parameters | # parse body parameters | ||||
| if 'schema' in interface['operation']['requestBody']['content'][content_type]: | if 'schema' in interface['operation']['requestBody']['content'][content_type]: | ||||
| body_schema = interface['operation']['requestBody']['content'][content_type]['schema'] | body_schema = interface['operation']['requestBody']['content'][content_type]['schema'] | ||||
| required = body_schema['required'] if 'required' in body_schema else [] | |||||
| properties = body_schema['properties'] if 'properties' in body_schema else {} | |||||
| required = body_schema.get('required', []) | |||||
| properties = body_schema.get('properties', {}) | |||||
| for name, property in properties.items(): | for name, property in properties.items(): | ||||
| tool = ToolParameter( | tool = ToolParameter( | ||||
| name=name, | name=name, | ||||
| zh_Hans=name | zh_Hans=name | ||||
| ), | ), | ||||
| human_description=I18nObject( | human_description=I18nObject( | ||||
| en_US=property['description'] if 'description' in property else '', | |||||
| zh_Hans=property['description'] if 'description' in property else '' | |||||
| en_US=property.get('description', ''), | |||||
| zh_Hans=property.get('description', '') | |||||
| ), | ), | ||||
| type=ToolParameter.ToolParameterType.STRING, | type=ToolParameter.ToolParameterType.STRING, | ||||
| required=name in required, | required=name in required, | ||||
| form=ToolParameter.ToolParameterForm.LLM, | form=ToolParameter.ToolParameterForm.LLM, | ||||
| llm_description=property['description'] if 'description' in property else '', | |||||
| default=property['default'] if 'default' in property else None, | |||||
| llm_description=property.get('description', ''), | |||||
| default=property.get('default', None), | |||||
| ) | ) | ||||
| # check if there is a type | # check if there is a type | ||||
| server_url=server_url + interface['path'], | server_url=server_url + interface['path'], | ||||
| method=interface['method'], | method=interface['method'], | ||||
| summary=interface['operation']['description'] if 'description' in interface['operation'] else | summary=interface['operation']['description'] if 'description' in interface['operation'] else | ||||
| interface['operation']['summary'] if 'summary' in interface['operation'] else None, | |||||
| interface['operation'].get('summary', None), | |||||
| operation_id=interface['operation']['operationId'], | operation_id=interface['operation']['operationId'], | ||||
| parameters=parameters, | parameters=parameters, | ||||
| author='', | author='', |
| # [Cn]: Other, Not Assigned | # [Cn]: Other, Not Assigned | ||||
| # [Co]: Other, Private Use | # [Co]: Other, Private Use | ||||
| # [Cs]: Other, Surrogate | # [Cs]: Other, Surrogate | ||||
| control_chars = set(['Cc', 'Cf', 'Cn', 'Co', 'Cs']) | |||||
| control_chars = {'Cc', 'Cf', 'Cn', 'Co', 'Cs'} | |||||
| retained_chars = ['\t', '\n', '\r', '\f'] | retained_chars = ['\t', '\n', '\r', '\f'] | ||||
| # Remove non-printing control characters | # Remove non-printing control characters |
| # fetch memory | # fetch memory | ||||
| memory = self._fetch_memory(node_data.memory, variable_pool, model_instance) | memory = self._fetch_memory(node_data.memory, variable_pool, model_instance) | ||||
| if set(model_schema.features or []) & set([ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL]) \ | |||||
| if set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL} \ | |||||
| and node_data.reasoning_mode == 'function_call': | and node_data.reasoning_mode == 'function_call': | ||||
| # use function call | # use function call | ||||
| prompt_messages, prompt_message_tools = self._generate_function_call_prompt( | prompt_messages, prompt_message_tools = self._generate_function_call_prompt( | ||||
| if not model_schema: | if not model_schema: | ||||
| raise ValueError("Model schema not found") | raise ValueError("Model schema not found") | ||||
| if set(model_schema.features or []) & set([ModelFeature.MULTI_TOOL_CALL, ModelFeature.MULTI_TOOL_CALL]): | |||||
| if set(model_schema.features or []) & {ModelFeature.MULTI_TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}: | |||||
| prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, None, 2000) | prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, None, 2000) | ||||
| else: | else: | ||||
| prompt_template = self._get_prompt_engineering_prompt_template(node_data, query, variable_pool, None, 2000) | prompt_template = self._get_prompt_engineering_prompt_template(node_data, query, variable_pool, None, 2000) |
| } | } | ||||
| response = requests.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers) | response = requests.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers) | ||||
| response_json = response.json() | response_json = response.json() | ||||
| if 'results' in response_json: | |||||
| results = response_json['results'] | |||||
| else: | |||||
| results = [] | |||||
| results = response_json.get('results', []) | |||||
| return results | return results | ||||
| def notion_block_parent_page_id(self, access_token: str, block_id: str): | def notion_block_parent_page_id(self, access_token: str, block_id: str): | ||||
| } | } | ||||
| response = requests.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers) | response = requests.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers) | ||||
| response_json = response.json() | response_json = response.json() | ||||
| if 'results' in response_json: | |||||
| results = response_json['results'] | |||||
| else: | |||||
| results = [] | |||||
| results = response_json.get('results', []) | |||||
| return results | return results |
| preview = true | preview = true | ||||
| select = [ | select = [ | ||||
| "B", # flake8-bugbear rules | "B", # flake8-bugbear rules | ||||
| "C4", # flake8-comprehensions | |||||
| "F", # pyflakes rules | "F", # pyflakes rules | ||||
| "I", # isort rules | "I", # isort rules | ||||
| "UP", # pyupgrade rules | |||||
| "UP", # pyupgrade rules | |||||
| "B035", # static-key-dict-comprehension | |||||
| "E101", # mixed-spaces-and-tabs | "E101", # mixed-spaces-and-tabs | ||||
| "E111", # indentation-with-invalid-multiple | "E111", # indentation-with-invalid-multiple | ||||
| "E112", # no-indented-block | "E112", # no-indented-block | ||||
| "RUF100", # unused-noqa | "RUF100", # unused-noqa | ||||
| "RUF101", # redirected-noqa | "RUF101", # redirected-noqa | ||||
| "S506", # unsafe-yaml-load | "S506", # unsafe-yaml-load | ||||
| "SIM116", # if-else-block-instead-of-dict-lookup | |||||
| "SIM401", # if-else-block-instead-of-dict-get | |||||
| "SIM910", # dict-get-with-none-default | |||||
| "W191", # tab-indentation | "W191", # tab-indentation | ||||
| "W605", # invalid-escape-sequence | "W605", # invalid-escape-sequence | ||||
| "F601", # multi-value-repeated-key-literal | |||||
| "F602", # multi-value-repeated-key-variable | |||||
| ] | ] | ||||
| ignore = [ | ignore = [ | ||||
| "F403", # undefined-local-with-import-star | "F403", # undefined-local-with-import-star | ||||
| HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL = "c" | HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL = "c" | ||||
| MOCK_SWITCH = "true" | MOCK_SWITCH = "true" | ||||
| CODE_MAX_STRING_LENGTH = "80000" | CODE_MAX_STRING_LENGTH = "80000" | ||||
| CODE_EXECUTION_ENDPOINT="http://127.0.0.1:8194" | |||||
| CODE_EXECUTION_API_KEY="dify-sandbox" | |||||
| CODE_EXECUTION_ENDPOINT = "http://127.0.0.1:8194" | |||||
| CODE_EXECUTION_API_KEY = "dify-sandbox" | |||||
| FIRECRAWL_API_KEY = "fc-" | FIRECRAWL_API_KEY = "fc-" | ||||
| [tool.poetry] | [tool.poetry] | ||||
| weaviate-client = "~3.21.0" | weaviate-client = "~3.21.0" | ||||
| mailchimp-transactional = "~1.0.50" | mailchimp-transactional = "~1.0.50" | ||||
| scikit-learn = "1.2.2" | scikit-learn = "1.2.2" | ||||
| sentry-sdk = {version = "~1.39.2", extras = ["flask"]} | |||||
| sentry-sdk = { version = "~1.39.2", extras = ["flask"] } | |||||
| sympy = "1.12" | sympy = "1.12" | ||||
| jieba = "0.42.1" | jieba = "0.42.1" | ||||
| celery = "~5.3.6" | celery = "~5.3.6" | ||||
| redis = {version = "~5.0.3", extras = ["hiredis"]} | |||||
| redis = { version = "~5.0.3", extras = ["hiredis"] } | |||||
| chardet = "~5.1.0" | chardet = "~5.1.0" | ||||
| python-docx = "~1.1.0" | python-docx = "~1.1.0" | ||||
| pypdfium2 = "~4.17.0" | pypdfium2 = "~4.17.0" | ||||
| google-cloud-storage = "2.16.0" | google-cloud-storage = "2.16.0" | ||||
| replicate = "~0.22.0" | replicate = "~0.22.0" | ||||
| websocket-client = "~1.7.0" | websocket-client = "~1.7.0" | ||||
| dashscope = {version = "~1.17.0", extras = ["tokenizer"]} | |||||
| dashscope = { version = "~1.17.0", extras = ["tokenizer"] } | |||||
| huggingface-hub = "~0.16.4" | huggingface-hub = "~0.16.4" | ||||
| transformers = "~4.35.0" | transformers = "~4.35.0" | ||||
| tokenizers = "~0.15.0" | tokenizers = "~0.15.0" | ||||
| cohere = "~5.2.4" | cohere = "~5.2.4" | ||||
| pyyaml = "~6.0.1" | pyyaml = "~6.0.1" | ||||
| numpy = "~1.26.4" | numpy = "~1.26.4" | ||||
| unstructured = {version = "~0.10.27", extras = ["docx", "epub", "md", "msg", "ppt", "pptx"]} | |||||
| unstructured = { version = "~0.10.27", extras = ["docx", "epub", "md", "msg", "ppt", "pptx"] } | |||||
| bs4 = "~0.0.1" | bs4 = "~0.0.1" | ||||
| markdown = "~3.5.1" | markdown = "~3.5.1" | ||||
| httpx = {version = "~0.27.0", extras = ["socks"]} | |||||
| httpx = { version = "~0.27.0", extras = ["socks"] } | |||||
| matplotlib = "~3.8.2" | matplotlib = "~3.8.2" | ||||
| yfinance = "~0.2.40" | yfinance = "~0.2.40" | ||||
| pydub = "~0.25.1" | pydub = "~0.25.1" | ||||
| pymysql = "1.1.1" | pymysql = "1.1.1" | ||||
| tidb-vector = "0.0.9" | tidb-vector = "0.0.9" | ||||
| google-cloud-aiplatform = "1.49.0" | google-cloud-aiplatform = "1.49.0" | ||||
| vanna = {version = "0.5.5", extras = ["postgres", "mysql", "clickhouse", "duckdb"]} | |||||
| vanna = { version = "0.5.5", extras = ["postgres", "mysql", "clickhouse", "duckdb"] } | |||||
| kaleido = "0.2.1" | kaleido = "0.2.1" | ||||
| tencentcloud-sdk-python-hunyuan = "~3.0.1158" | tencentcloud-sdk-python-hunyuan = "~3.0.1158" | ||||
| tcvectordb = "1.3.2" | tcvectordb = "1.3.2" |
| elif document_data["data_source"]["type"] == "notion_import": | elif document_data["data_source"]["type"] == "notion_import": | ||||
| notion_info_list = document_data["data_source"]['info_list']['notion_info_list'] | notion_info_list = document_data["data_source"]['info_list']['notion_info_list'] | ||||
| exist_page_ids = [] | exist_page_ids = [] | ||||
| exist_document = dict() | |||||
| exist_document = {} | |||||
| documents = Document.query.filter_by( | documents = Document.query.filter_by( | ||||
| dataset_id=dataset.id, | dataset_id=dataset.id, | ||||
| tenant_id=current_user.current_tenant_id, | tenant_id=current_user.current_tenant_id, |
| categories.add(recommended_app.category) # add category to categories | categories.add(recommended_app.category) # add category to categories | ||||
| return {'recommended_apps': recommended_apps_result, 'categories': sorted(list(categories))} | |||||
| return {'recommended_apps': recommended_apps_result, 'categories': sorted(categories)} | |||||
| @classmethod | @classmethod | ||||
| def _fetch_recommended_apps_from_dify_official(cls, language: str) -> dict: | def _fetch_recommended_apps_from_dify_official(cls, language: str) -> dict: |
| prompt_rules = prompt_template_config['prompt_rules'] | prompt_rules = prompt_template_config['prompt_rules'] | ||||
| role_prefix = { | role_prefix = { | ||||
| "user": prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human', | |||||
| "assistant": prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant' | |||||
| "user": prompt_rules.get('human_prefix', 'Human'), | |||||
| "assistant": prompt_rules.get('assistant_prefix', 'Assistant') | |||||
| } | } | ||||
| else: | else: | ||||
| advanced_completion_prompt_template = prompt_template.advanced_completion_prompt_template | advanced_completion_prompt_template = prompt_template.advanced_completion_prompt_template |
| # Mock db.session.close() | # Mock db.session.close() | ||||
| db.session.close = MagicMock() | db.session.close = MagicMock() | ||||
| node._fetch_model_config = MagicMock(return_value=tuple([model_instance, model_config])) | |||||
| node._fetch_model_config = MagicMock(return_value=(model_instance, model_config)) | |||||
| # execute node | # execute node | ||||
| result = node.run(pool) | result = node.run(pool) | ||||
| # Mock db.session.close() | # Mock db.session.close() | ||||
| db.session.close = MagicMock() | db.session.close = MagicMock() | ||||
| node._fetch_model_config = MagicMock(return_value=tuple([model_instance, model_config])) | |||||
| node._fetch_model_config = MagicMock(return_value=(model_instance, model_config)) | |||||
| # execute node | # execute node | ||||
| result = node.run(pool) | result = node.run(pool) |
| provider_model_bundle=provider_model_bundle | provider_model_bundle=provider_model_bundle | ||||
| ) | ) | ||||
| return MagicMock(return_value=tuple([model_instance, model_config])) | |||||
| return MagicMock(return_value=(model_instance, model_config)) | |||||
| @pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) | @pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True) | ||||
| def test_function_calling_parameter_extractor(setup_openai_mock): | def test_function_calling_parameter_extractor(setup_openai_mock): |
| prompt_rules = prompt_template['prompt_rules'] | prompt_rules = prompt_template['prompt_rules'] | ||||
| full_inputs = {**inputs, '#context#': context, '#query#': query, '#histories#': memory.get_history_prompt_text( | full_inputs = {**inputs, '#context#': context, '#query#': query, '#histories#': memory.get_history_prompt_text( | ||||
| max_token_limit=2000, | max_token_limit=2000, | ||||
| human_prefix=prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human', | |||||
| ai_prefix=prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant' | |||||
| human_prefix=prompt_rules.get("human_prefix", "Human"), | |||||
| ai_prefix=prompt_rules.get("assistant_prefix", "Assistant") | |||||
| )} | )} | ||||
| real_prompt = prompt_template['prompt_template'].format(full_inputs) | real_prompt = prompt_template['prompt_template'].format(full_inputs) | ||||