| @@ -58,7 +58,8 @@ class HostingConfiguration: | |||
| self.moderation_config = self.init_moderation_config(config) | |||
| def init_azure_openai(self, app_config: Config) -> HostingProvider: | |||
| @staticmethod | |||
| def init_azure_openai(app_config: Config) -> HostingProvider: | |||
| quota_unit = QuotaUnit.TIMES | |||
| if app_config.get("HOSTED_AZURE_OPENAI_ENABLED"): | |||
| credentials = { | |||
| @@ -145,7 +146,8 @@ class HostingConfiguration: | |||
| quota_unit=quota_unit, | |||
| ) | |||
| def init_anthropic(self, app_config: Config) -> HostingProvider: | |||
| @staticmethod | |||
| def init_anthropic(app_config: Config) -> HostingProvider: | |||
| quota_unit = QuotaUnit.TOKENS | |||
| quotas = [] | |||
| @@ -180,7 +182,8 @@ class HostingConfiguration: | |||
| quota_unit=quota_unit, | |||
| ) | |||
| def init_minimax(self, app_config: Config) -> HostingProvider: | |||
| @staticmethod | |||
| def init_minimax(app_config: Config) -> HostingProvider: | |||
| quota_unit = QuotaUnit.TOKENS | |||
| if app_config.get("HOSTED_MINIMAX_ENABLED"): | |||
| quotas = [FreeHostingQuota()] | |||
| @@ -197,7 +200,8 @@ class HostingConfiguration: | |||
| quota_unit=quota_unit, | |||
| ) | |||
| def init_spark(self, app_config: Config) -> HostingProvider: | |||
| @staticmethod | |||
| def init_spark(app_config: Config) -> HostingProvider: | |||
| quota_unit = QuotaUnit.TOKENS | |||
| if app_config.get("HOSTED_SPARK_ENABLED"): | |||
| quotas = [FreeHostingQuota()] | |||
| @@ -214,7 +218,8 @@ class HostingConfiguration: | |||
| quota_unit=quota_unit, | |||
| ) | |||
| def init_zhipuai(self, app_config: Config) -> HostingProvider: | |||
| @staticmethod | |||
| def init_zhipuai(app_config: Config) -> HostingProvider: | |||
| quota_unit = QuotaUnit.TOKENS | |||
| if app_config.get("HOSTED_ZHIPUAI_ENABLED"): | |||
| quotas = [FreeHostingQuota()] | |||
| @@ -231,7 +236,8 @@ class HostingConfiguration: | |||
| quota_unit=quota_unit, | |||
| ) | |||
| def init_moderation_config(self, app_config: Config) -> HostedModerationConfig: | |||
| @staticmethod | |||
| def init_moderation_config(app_config: Config) -> HostedModerationConfig: | |||
| if app_config.get("HOSTED_MODERATION_ENABLED") \ | |||
| and app_config.get("HOSTED_MODERATION_PROVIDERS"): | |||
| return HostedModerationConfig( | |||
| @@ -411,7 +411,8 @@ class IndexingRunner: | |||
| return text_docs | |||
| def filter_string(self, text): | |||
| @staticmethod | |||
| def filter_string(text): | |||
| text = re.sub(r'<\|', '<', text) | |||
| text = re.sub(r'\|>', '>', text) | |||
| text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\xEF\xBF\xBE]', '', text) | |||
| @@ -419,7 +420,8 @@ class IndexingRunner: | |||
| text = re.sub('\uFFFE', '', text) | |||
| return text | |||
| def _get_splitter(self, processing_rule: DatasetProcessRule, | |||
| @staticmethod | |||
| def _get_splitter(processing_rule: DatasetProcessRule, | |||
| embedding_model_instance: Optional[ModelInstance]) -> TextSplitter: | |||
| """ | |||
| Get the NodeParser object according to the processing rule. | |||
| @@ -611,7 +613,8 @@ class IndexingRunner: | |||
| return all_documents | |||
| def _document_clean(self, text: str, processing_rule: DatasetProcessRule) -> str: | |||
| @staticmethod | |||
| def _document_clean(text: str, processing_rule: DatasetProcessRule) -> str: | |||
| """ | |||
| Clean the document text according to the processing rules. | |||
| """ | |||
| @@ -640,7 +643,8 @@ class IndexingRunner: | |||
| return text | |||
| def format_split_text(self, text): | |||
| @staticmethod | |||
| def format_split_text(text): | |||
| regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)" | |||
| matches = re.findall(regex, text, re.UNICODE) | |||
| @@ -704,7 +708,8 @@ class IndexingRunner: | |||
| } | |||
| ) | |||
| def _process_keyword_index(self, flask_app, dataset_id, document_id, documents): | |||
| @staticmethod | |||
| def _process_keyword_index(flask_app, dataset_id, document_id, documents): | |||
| with flask_app.app_context(): | |||
| dataset = Dataset.query.filter_by(id=dataset_id).first() | |||
| if not dataset: | |||
| @@ -758,13 +763,15 @@ class IndexingRunner: | |||
| return tokens | |||
| def _check_document_paused_status(self, document_id: str): | |||
| @staticmethod | |||
| def _check_document_paused_status(document_id: str): | |||
| indexing_cache_key = 'document_{}_is_paused'.format(document_id) | |||
| result = redis_client.get(indexing_cache_key) | |||
| if result: | |||
| raise DocumentIsPausedException() | |||
| def _update_document_index_status(self, document_id: str, after_indexing_status: str, | |||
| @staticmethod | |||
| def _update_document_index_status(document_id: str, after_indexing_status: str, | |||
| extra_update_params: Optional[dict] = None) -> None: | |||
| """ | |||
| Update the document indexing status. | |||
| @@ -786,14 +793,16 @@ class IndexingRunner: | |||
| DatasetDocument.query.filter_by(id=document_id).update(update_params) | |||
| db.session.commit() | |||
| def _update_segments_by_document(self, dataset_document_id: str, update_params: dict) -> None: | |||
| @staticmethod | |||
| def _update_segments_by_document(dataset_document_id: str, update_params: dict) -> None: | |||
| """ | |||
| Update the document segment by document id. | |||
| """ | |||
| DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params) | |||
| db.session.commit() | |||
| def batch_add_segments(self, segments: list[DocumentSegment], dataset: Dataset): | |||
| @staticmethod | |||
| def batch_add_segments(segments: list[DocumentSegment], dataset: Dataset): | |||
| """ | |||
| Batch add segments index processing | |||
| """ | |||
| @@ -44,7 +44,8 @@ class ModelInstance: | |||
| credentials=self.credentials | |||
| ) | |||
| def _fetch_credentials_from_bundle(self, provider_model_bundle: ProviderModelBundle, model: str) -> dict: | |||
| @staticmethod | |||
| def _fetch_credentials_from_bundle(provider_model_bundle: ProviderModelBundle, model: str) -> dict: | |||
| """ | |||
| Fetch credentials from provider model bundle | |||
| :param provider_model_bundle: provider model bundle | |||
| @@ -63,7 +64,8 @@ class ModelInstance: | |||
| return credentials | |||
| def _get_load_balancing_manager(self, configuration: ProviderConfiguration, | |||
| @staticmethod | |||
| def _get_load_balancing_manager(configuration: ProviderConfiguration, | |||
| model_type: ModelType, | |||
| model: str, | |||
| credentials: dict) -> Optional["LBModelManager"]: | |||
| @@ -515,8 +517,8 @@ class LBModelManager: | |||
| res = cast(bool, res) | |||
| return res | |||
| @classmethod | |||
| def get_config_in_cooldown_and_ttl(cls, tenant_id: str, | |||
| @staticmethod | |||
| def get_config_in_cooldown_and_ttl(tenant_id: str, | |||
| provider: str, | |||
| model_type: ModelType, | |||
| model: str, | |||
| @@ -350,7 +350,8 @@ class ProviderManager: | |||
| return default_model | |||
| def _get_all_providers(self, tenant_id: str) -> dict[str, list[Provider]]: | |||
| @staticmethod | |||
| def _get_all_providers(tenant_id: str) -> dict[str, list[Provider]]: | |||
| """ | |||
| Get all provider records of the workspace. | |||
| @@ -369,7 +370,8 @@ class ProviderManager: | |||
| return provider_name_to_provider_records_dict | |||
| def _get_all_provider_models(self, tenant_id: str) -> dict[str, list[ProviderModel]]: | |||
| @staticmethod | |||
| def _get_all_provider_models(tenant_id: str) -> dict[str, list[ProviderModel]]: | |||
| """ | |||
| Get all provider model records of the workspace. | |||
| @@ -389,7 +391,8 @@ class ProviderManager: | |||
| return provider_name_to_provider_model_records_dict | |||
| def _get_all_preferred_model_providers(self, tenant_id: str) -> dict[str, TenantPreferredModelProvider]: | |||
| @staticmethod | |||
| def _get_all_preferred_model_providers(tenant_id: str) -> dict[str, TenantPreferredModelProvider]: | |||
| """ | |||
| Get All preferred provider types of the workspace. | |||
| @@ -408,7 +411,8 @@ class ProviderManager: | |||
| return provider_name_to_preferred_provider_type_records_dict | |||
| def _get_all_provider_model_settings(self, tenant_id: str) -> dict[str, list[ProviderModelSetting]]: | |||
| @staticmethod | |||
| def _get_all_provider_model_settings(tenant_id: str) -> dict[str, list[ProviderModelSetting]]: | |||
| """ | |||
| Get All provider model settings of the workspace. | |||
| @@ -427,7 +431,8 @@ class ProviderManager: | |||
| return provider_name_to_provider_model_settings_dict | |||
| def _get_all_provider_load_balancing_configs(self, tenant_id: str) -> dict[str, list[LoadBalancingModelConfig]]: | |||
| @staticmethod | |||
| def _get_all_provider_load_balancing_configs(tenant_id: str) -> dict[str, list[LoadBalancingModelConfig]]: | |||
| """ | |||
| Get All provider load balancing configs of the workspace. | |||
| @@ -458,7 +463,8 @@ class ProviderManager: | |||
| return provider_name_to_provider_load_balancing_model_configs_dict | |||
| def _init_trial_provider_records(self, tenant_id: str, | |||
| @staticmethod | |||
| def _init_trial_provider_records(tenant_id: str, | |||
| provider_name_to_provider_records_dict: dict[str, list]) -> dict[str, list]: | |||
| """ | |||
| Initialize trial provider records if not exists. | |||
| @@ -791,7 +797,8 @@ class ProviderManager: | |||
| credentials=current_using_credentials | |||
| ) | |||
| def _choice_current_using_quota_type(self, quota_configurations: list[QuotaConfiguration]) -> ProviderQuotaType: | |||
| @staticmethod | |||
| def _choice_current_using_quota_type(quota_configurations: list[QuotaConfiguration]) -> ProviderQuotaType: | |||
| """ | |||
| Choice current using quota type. | |||
| paid quotas > provider free quotas > hosting trial quotas | |||
| @@ -818,7 +825,8 @@ class ProviderManager: | |||
| raise ValueError('No quota type available') | |||
| def _extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]: | |||
| @staticmethod | |||
| def _extract_secret_variables(credential_form_schemas: list[CredentialFormSchema]) -> list[str]: | |||
| """ | |||
| Extract secret input form variables. | |||