| @@ -3,7 +3,7 @@ import logging | |||
| import click | |||
| from core.entities import DEFAULT_PLUGIN_ID | |||
| from core.plugin.entities.plugin import GenericProviderID, ModelProviderID, ToolProviderID | |||
| from models.engine import db | |||
| logger = logging.getLogger(__name__) | |||
| @@ -12,17 +12,17 @@ logger = logging.getLogger(__name__) | |||
| class PluginDataMigration: | |||
| @classmethod | |||
| def migrate(cls) -> None: | |||
| cls.migrate_db_records("providers", "provider_name") # large table | |||
| cls.migrate_db_records("provider_models", "provider_name") | |||
| cls.migrate_db_records("provider_orders", "provider_name") | |||
| cls.migrate_db_records("tenant_default_models", "provider_name") | |||
| cls.migrate_db_records("tenant_preferred_model_providers", "provider_name") | |||
| cls.migrate_db_records("provider_model_settings", "provider_name") | |||
| cls.migrate_db_records("load_balancing_model_configs", "provider_name") | |||
| cls.migrate_db_records("providers", "provider_name", ModelProviderID) # large table | |||
| cls.migrate_db_records("provider_models", "provider_name", ModelProviderID) | |||
| cls.migrate_db_records("provider_orders", "provider_name", ModelProviderID) | |||
| cls.migrate_db_records("tenant_default_models", "provider_name", ModelProviderID) | |||
| cls.migrate_db_records("tenant_preferred_model_providers", "provider_name", ModelProviderID) | |||
| cls.migrate_db_records("provider_model_settings", "provider_name", ModelProviderID) | |||
| cls.migrate_db_records("load_balancing_model_configs", "provider_name", ModelProviderID) | |||
| cls.migrate_datasets() | |||
| cls.migrate_db_records("embeddings", "provider_name") # large table | |||
| cls.migrate_db_records("dataset_collection_bindings", "provider_name") | |||
| cls.migrate_db_records("tool_builtin_providers", "provider") | |||
| cls.migrate_db_records("embeddings", "provider_name", ModelProviderID) # large table | |||
| cls.migrate_db_records("dataset_collection_bindings", "provider_name", ModelProviderID) | |||
| cls.migrate_db_records("tool_builtin_providers", "provider_name", ToolProviderID) | |||
| @classmethod | |||
| def migrate_datasets(cls) -> None: | |||
| @@ -66,9 +66,10 @@ limit 1000""" | |||
| fg="white", | |||
| ) | |||
| ) | |||
| retrieval_model["reranking_model"]["reranking_provider_name"] = ( | |||
| f"{DEFAULT_PLUGIN_ID}/{retrieval_model['reranking_model']['reranking_provider_name']}/{retrieval_model['reranking_model']['reranking_provider_name']}" | |||
| ) | |||
| # update google to langgenius/gemini/google etc. | |||
| retrieval_model["reranking_model"]["reranking_provider_name"] = ModelProviderID( | |||
| retrieval_model["reranking_model"]["reranking_provider_name"] | |||
| ).to_string() | |||
| retrieval_model_changed = True | |||
| click.echo( | |||
| @@ -86,9 +87,11 @@ limit 1000""" | |||
| update_retrieval_model_sql = ", retrieval_model = :retrieval_model" | |||
| params["retrieval_model"] = json.dumps(retrieval_model) | |||
| params["provider_name"] = ModelProviderID(provider_name).to_string() | |||
| sql = f"""update {table_name} | |||
| set {provider_column_name} = | |||
| concat('{DEFAULT_PLUGIN_ID}/', {provider_column_name}, '/', {provider_column_name}) | |||
| :provider_name | |||
| {update_retrieval_model_sql} | |||
| where id = :record_id""" | |||
| conn.execute(db.text(sql), params) | |||
| @@ -122,7 +125,9 @@ limit 1000""" | |||
| ) | |||
| @classmethod | |||
| def migrate_db_records(cls, table_name: str, provider_column_name: str) -> None: | |||
| def migrate_db_records( | |||
| cls, table_name: str, provider_column_name: str, provider_cls: type[GenericProviderID] | |||
| ) -> None: | |||
| click.echo(click.style(f"Migrating [{table_name}] data for plugin", fg="white")) | |||
| processed_count = 0 | |||
| @@ -166,7 +171,8 @@ limit 1000""" | |||
| ) | |||
| try: | |||
| updated_value = f"{DEFAULT_PLUGIN_ID}/{provider_name}/{provider_name}" | |||
| # update jina to langgenius/jina_tool/jina etc. | |||
| updated_value = provider_cls(provider_name).to_string() | |||
| batch_updates.append((updated_value, record_id)) | |||
| except Exception as e: | |||
| failed_ids.append(record_id) | |||