| @@ -267,4 +267,13 @@ APP_MAX_ACTIVE_REQUESTS=0 | |||
| # Celery beat configuration | |||
| CELERY_BEAT_SCHEDULER_TIME=1 | |||
| CELERY_BEAT_SCHEDULER_TIME=1 | |||
| # Position configuration | |||
| POSITION_TOOL_PINS= | |||
| POSITION_TOOL_INCLUDES= | |||
| POSITION_TOOL_EXCLUDES= | |||
| POSITION_PROVIDER_PINS= | |||
| POSITION_PROVIDER_INCLUDES= | |||
| POSITION_PROVIDER_EXCLUDES= | |||
| @@ -406,6 +406,7 @@ class DataSetConfig(BaseSettings): | |||
| default=False, | |||
| ) | |||
| class WorkspaceConfig(BaseSettings): | |||
| """ | |||
| Workspace configs | |||
| @@ -442,6 +443,63 @@ class CeleryBeatConfig(BaseSettings): | |||
| ) | |||
| class PositionConfig(BaseSettings): | |||
| POSITION_PROVIDER_PINS: str = Field( | |||
| description='The heads of model providers', | |||
| default='', | |||
| ) | |||
| POSITION_PROVIDER_INCLUDES: str = Field( | |||
| description='The included model providers', | |||
| default='', | |||
| ) | |||
| POSITION_PROVIDER_EXCLUDES: str = Field( | |||
| description='The excluded model providers', | |||
| default='', | |||
| ) | |||
| POSITION_TOOL_PINS: str = Field( | |||
| description='The heads of tools', | |||
| default='', | |||
| ) | |||
| POSITION_TOOL_INCLUDES: str = Field( | |||
| description='The included tools', | |||
| default='', | |||
| ) | |||
| POSITION_TOOL_EXCLUDES: str = Field( | |||
| description='The excluded tools', | |||
| default='', | |||
| ) | |||
| @computed_field | |||
| def POSITION_PROVIDER_PINS_LIST(self) -> list[str]: | |||
| return [item.strip() for item in self.POSITION_PROVIDER_PINS.split(',') if item.strip() != ''] | |||
| @computed_field | |||
| def POSITION_PROVIDER_INCLUDES_LIST(self) -> list[str]: | |||
| return [item.strip() for item in self.POSITION_PROVIDER_INCLUDES.split(',') if item.strip() != ''] | |||
| @computed_field | |||
| def POSITION_PROVIDER_EXCLUDES_LIST(self) -> list[str]: | |||
| return [item.strip() for item in self.POSITION_PROVIDER_EXCLUDES.split(',') if item.strip() != ''] | |||
| @computed_field | |||
| def POSITION_TOOL_PINS_LIST(self) -> list[str]: | |||
| return [item.strip() for item in self.POSITION_TOOL_PINS.split(',') if item.strip() != ''] | |||
| @computed_field | |||
| def POSITION_TOOL_INCLUDES_LIST(self) -> list[str]: | |||
| return [item.strip() for item in self.POSITION_TOOL_INCLUDES.split(',') if item.strip() != ''] | |||
| @computed_field | |||
| def POSITION_TOOL_EXCLUDES_LIST(self) -> list[str]: | |||
| return [item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(',') if item.strip() != ''] | |||
| class FeatureConfig( | |||
| # place the configs in alphabet order | |||
| AppExecutionConfig, | |||
| @@ -466,6 +524,7 @@ class FeatureConfig( | |||
| UpdateConfig, | |||
| WorkflowConfig, | |||
| WorkspaceConfig, | |||
| PositionConfig, | |||
| # hosted services config | |||
| HostedServiceConfig, | |||
| @@ -3,12 +3,13 @@ from collections import OrderedDict | |||
| from collections.abc import Callable | |||
| from typing import Any | |||
| from configs import dify_config | |||
| from core.tools.utils.yaml_utils import load_yaml_file | |||
| def get_position_map(folder_path: str, *, file_name: str = "_position.yaml") -> dict[str, int]: | |||
| """ | |||
| Get the mapping from name to index from a YAML file | |||
| Get the mapping from name to index from a YAML file. | |||
| :param folder_path: | |||
| :param file_name: the YAML file name, default to '_position.yaml' | |||
| :return: a dict with name as key and index as value | |||
| @@ -19,6 +20,64 @@ def get_position_map(folder_path: str, *, file_name: str = "_position.yaml") -> | |||
| return {name: index for index, name in enumerate(positions)} | |||
| def get_tool_position_map(folder_path: str, file_name: str = "_position.yaml") -> dict[str, int]: | |||
| """ | |||
| Get the mapping for tools from name to index from a YAML file. | |||
| :param folder_path: | |||
| :param file_name: the YAML file name, default to '_position.yaml' | |||
| :return: a dict with name as key and index as value | |||
| """ | |||
| position_map = get_position_map(folder_path, file_name=file_name) | |||
| return sort_and_filter_position_map( | |||
| position_map, | |||
| pin_list=dify_config.POSITION_TOOL_PINS_LIST, | |||
| include_list=dify_config.POSITION_TOOL_INCLUDES_LIST, | |||
| exclude_list=dify_config.POSITION_TOOL_EXCLUDES_LIST | |||
| ) | |||
| def get_provider_position_map(folder_path: str, file_name: str = "_position.yaml") -> dict[str, int]: | |||
| """ | |||
| Get the mapping for providers from name to index from a YAML file. | |||
| :param folder_path: | |||
| :param file_name: the YAML file name, default to '_position.yaml' | |||
| :return: a dict with name as key and index as value | |||
| """ | |||
| position_map = get_position_map(folder_path, file_name=file_name) | |||
| return sort_and_filter_position_map( | |||
| position_map, | |||
| pin_list=dify_config.POSITION_PROVIDER_PINS_LIST, | |||
| include_list=dify_config.POSITION_PROVIDER_INCLUDES_LIST, | |||
| exclude_list=dify_config.POSITION_PROVIDER_EXCLUDES_LIST | |||
| ) | |||
| def sort_and_filter_position_map(original_position_map: dict[str, int], pin_list: list[str], include_list: list[str], exclude_list: list[str]) -> dict[str, int]: | |||
| """ | |||
| Sort and filter the positions | |||
| :param position_map: the position map to be sorted and filtered | |||
| :param pin_list: the list of pins to be put at the beginning | |||
| :param include_set: the set of names to be included | |||
| :param exclude_set: the set of names to be excluded | |||
| :return: the sorted and filtered position map | |||
| """ | |||
| positions = sorted(original_position_map.keys(), key=lambda x: original_position_map[x]) | |||
| include_set = set(include_list) if include_list else set(positions) | |||
| exclude_set = set(exclude_list) if exclude_list else set() | |||
| # Add pins to position map | |||
| position_map = {name: idx for idx, name in enumerate(pin_list) if name in original_position_map} | |||
| # Add remaining positions to position map, respecting include and exclude lists | |||
| start_idx = len(position_map) | |||
| for name in positions: | |||
| if name in include_set and name not in exclude_set and name not in position_map: | |||
| position_map[name] = start_idx | |||
| start_idx += 1 | |||
| return position_map | |||
| def sort_by_position_map( | |||
| position_map: dict[str, int], | |||
| data: list[Any], | |||
| @@ -35,7 +94,9 @@ def sort_by_position_map( | |||
| if not position_map or not data: | |||
| return data | |||
| return sorted(data, key=lambda x: position_map.get(name_func(x), float('inf'))) | |||
| filtered_data = [item for item in data if name_func(item) in position_map] | |||
| return sorted(filtered_data, key=lambda x: position_map.get(name_func(x), float('inf'))) | |||
| def sort_to_dict_by_position_map( | |||
| @@ -151,9 +151,9 @@ class AIModel(ABC): | |||
| os.path.join(provider_model_type_path, model_schema_yaml) | |||
| for model_schema_yaml in os.listdir(provider_model_type_path) | |||
| if not model_schema_yaml.startswith('__') | |||
| and not model_schema_yaml.startswith('_') | |||
| and os.path.isfile(os.path.join(provider_model_type_path, model_schema_yaml)) | |||
| and model_schema_yaml.endswith('.yaml') | |||
| and not model_schema_yaml.startswith('_') | |||
| and os.path.isfile(os.path.join(provider_model_type_path, model_schema_yaml)) | |||
| and model_schema_yaml.endswith('.yaml') | |||
| ] | |||
| # get _position.yaml file path | |||
| @@ -6,7 +6,7 @@ from typing import Optional | |||
| from pydantic import BaseModel, ConfigDict | |||
| from core.helper.module_import_helper import load_single_subclass_from_source | |||
| from core.helper.position_helper import get_position_map, sort_to_dict_by_position_map | |||
| from core.helper.position_helper import get_provider_position_map, sort_to_dict_by_position_map | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity | |||
| from core.model_runtime.model_providers.__base.model_provider import ModelProvider | |||
| @@ -234,7 +234,7 @@ class ModelProviderFactory: | |||
| ] | |||
| # get _position.yaml file path | |||
| position_map = get_position_map(model_providers_path) | |||
| position_map = get_provider_position_map(model_providers_path) | |||
| # traverse all model_provider_dir_paths | |||
| model_providers: list[ModelProviderExtension] = [] | |||
| @@ -1,6 +1,6 @@ | |||
| import os.path | |||
| from core.helper.position_helper import get_position_map, sort_by_position_map | |||
| from core.helper.position_helper import get_tool_position_map, sort_by_position_map | |||
| from core.tools.entities.api_entities import UserToolProvider | |||
| @@ -10,11 +10,11 @@ class BuiltinToolProviderSort: | |||
| @classmethod | |||
| def sort(cls, providers: list[UserToolProvider]) -> list[UserToolProvider]: | |||
| if not cls._position: | |||
| cls._position = get_position_map(os.path.join(os.path.dirname(__file__), '..')) | |||
| cls._position = get_tool_position_map(os.path.join(os.path.dirname(__file__), '..')) | |||
| def name_func(provider: UserToolProvider) -> str: | |||
| return provider.name | |||
| sorted_providers = sort_by_position_map(cls._position, providers, name_func) | |||
| return sorted_providers | |||
| return sorted_providers | |||
| @@ -2,7 +2,7 @@ from textwrap import dedent | |||
| import pytest | |||
| from core.helper.position_helper import get_position_map | |||
| from core.helper.position_helper import get_position_map, sort_and_filter_position_map | |||
| @pytest.fixture | |||
| @@ -53,3 +53,47 @@ def test_position_helper_with_all_commented(prepare_empty_commented_positions_ya | |||
| folder_path=prepare_empty_commented_positions_yaml, | |||
| file_name='example_positions_all_commented.yaml') | |||
| assert position_map == {} | |||
| def test_excluded_position_map(prepare_example_positions_yaml): | |||
| position_map = get_position_map( | |||
| folder_path=prepare_example_positions_yaml, | |||
| file_name='example_positions.yaml' | |||
| ) | |||
| pin_list = ['forth', 'first'] | |||
| include_list = [] | |||
| exclude_list = ['9999999999999'] | |||
| sorted_filtered_position_map = sort_and_filter_position_map( | |||
| original_position_map=position_map, | |||
| pin_list=pin_list, | |||
| include_list=include_list, | |||
| exclude_list=exclude_list | |||
| ) | |||
| assert sorted_filtered_position_map == { | |||
| 'forth': 0, | |||
| 'first': 1, | |||
| 'second': 2, | |||
| 'third': 3, | |||
| } | |||
| def test_included_position_map(prepare_example_positions_yaml): | |||
| position_map = get_position_map( | |||
| folder_path=prepare_example_positions_yaml, | |||
| file_name='example_positions.yaml' | |||
| ) | |||
| pin_list = ['second', 'first'] | |||
| include_list = ['first', 'second', 'third', 'forth'] | |||
| exclude_list = [] | |||
| sorted_filtered_position_map = sort_and_filter_position_map( | |||
| original_position_map=position_map, | |||
| pin_list=pin_list, | |||
| include_list=include_list, | |||
| exclude_list=exclude_list | |||
| ) | |||
| assert sorted_filtered_position_map == { | |||
| 'second': 0, | |||
| 'first': 1, | |||
| 'third': 2, | |||
| 'forth': 3, | |||
| } | |||
| @@ -695,3 +695,22 @@ COMPOSE_PROFILES=${VECTOR_STORE:-weaviate} | |||
| # ------------------------------ | |||
| EXPOSE_NGINX_PORT=80 | |||
| EXPOSE_NGINX_SSL_PORT=443 | |||
| # ---------------------------------------------------------------------------- | |||
| # ModelProvider & Tool Position Configuration | |||
| # Used to specify the model providers and tools that can be used in the app. | |||
| # ---------------------------------------------------------------------------- | |||
| # Pin, include, and exclude tools | |||
| # Use comma-separated values with no spaces between items. | |||
| # Example: POSITION_TOOL_PINS=bing,google | |||
| POSITION_TOOL_PINS= | |||
| POSITION_TOOL_INCLUDES= | |||
| POSITION_TOOL_EXCLUDES= | |||
| # Pin, include, and exclude model providers | |||
| # Use comma-separated values with no spaces between items. | |||
| # Example: POSITION_PROVIDER_PINS=openai,openllm | |||
| POSITION_PROVIDER_PINS= | |||
| POSITION_PROVIDER_INCLUDES= | |||
| POSITION_PROVIDER_EXCLUDES= | |||