| import click | import click | ||||
| from core.entities import DEFAULT_PLUGIN_ID | |||||
| from core.plugin.entities.plugin import GenericProviderID, ModelProviderID, ToolProviderID | |||||
| from models.engine import db | from models.engine import db | ||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||
| class PluginDataMigration: | class PluginDataMigration: | ||||
| @classmethod | @classmethod | ||||
| def migrate(cls) -> None: | def migrate(cls) -> None: | ||||
| cls.migrate_db_records("providers", "provider_name") # large table | |||||
| cls.migrate_db_records("provider_models", "provider_name") | |||||
| cls.migrate_db_records("provider_orders", "provider_name") | |||||
| cls.migrate_db_records("tenant_default_models", "provider_name") | |||||
| cls.migrate_db_records("tenant_preferred_model_providers", "provider_name") | |||||
| cls.migrate_db_records("provider_model_settings", "provider_name") | |||||
| cls.migrate_db_records("load_balancing_model_configs", "provider_name") | |||||
| cls.migrate_db_records("providers", "provider_name", ModelProviderID) # large table | |||||
| cls.migrate_db_records("provider_models", "provider_name", ModelProviderID) | |||||
| cls.migrate_db_records("provider_orders", "provider_name", ModelProviderID) | |||||
| cls.migrate_db_records("tenant_default_models", "provider_name", ModelProviderID) | |||||
| cls.migrate_db_records("tenant_preferred_model_providers", "provider_name", ModelProviderID) | |||||
| cls.migrate_db_records("provider_model_settings", "provider_name", ModelProviderID) | |||||
| cls.migrate_db_records("load_balancing_model_configs", "provider_name", ModelProviderID) | |||||
| cls.migrate_datasets() | cls.migrate_datasets() | ||||
| cls.migrate_db_records("embeddings", "provider_name") # large table | |||||
| cls.migrate_db_records("dataset_collection_bindings", "provider_name") | |||||
| cls.migrate_db_records("tool_builtin_providers", "provider") | |||||
| cls.migrate_db_records("embeddings", "provider_name", ModelProviderID) # large table | |||||
| cls.migrate_db_records("dataset_collection_bindings", "provider_name", ModelProviderID) | |||||
| cls.migrate_db_records("tool_builtin_providers", "provider_name", ToolProviderID) | |||||
| @classmethod | @classmethod | ||||
| def migrate_datasets(cls) -> None: | def migrate_datasets(cls) -> None: | ||||
| fg="white", | fg="white", | ||||
| ) | ) | ||||
| ) | ) | ||||
| retrieval_model["reranking_model"]["reranking_provider_name"] = ( | |||||
| f"{DEFAULT_PLUGIN_ID}/{retrieval_model['reranking_model']['reranking_provider_name']}/{retrieval_model['reranking_model']['reranking_provider_name']}" | |||||
| ) | |||||
| # update google to langgenius/gemini/google etc. | |||||
| retrieval_model["reranking_model"]["reranking_provider_name"] = ModelProviderID( | |||||
| retrieval_model["reranking_model"]["reranking_provider_name"] | |||||
| ).to_string() | |||||
| retrieval_model_changed = True | retrieval_model_changed = True | ||||
| click.echo( | click.echo( | ||||
| update_retrieval_model_sql = ", retrieval_model = :retrieval_model" | update_retrieval_model_sql = ", retrieval_model = :retrieval_model" | ||||
| params["retrieval_model"] = json.dumps(retrieval_model) | params["retrieval_model"] = json.dumps(retrieval_model) | ||||
| params["provider_name"] = ModelProviderID(provider_name).to_string() | |||||
| sql = f"""update {table_name} | sql = f"""update {table_name} | ||||
| set {provider_column_name} = | set {provider_column_name} = | ||||
| concat('{DEFAULT_PLUGIN_ID}/', {provider_column_name}, '/', {provider_column_name}) | |||||
| :provider_name | |||||
| {update_retrieval_model_sql} | {update_retrieval_model_sql} | ||||
| where id = :record_id""" | where id = :record_id""" | ||||
| conn.execute(db.text(sql), params) | conn.execute(db.text(sql), params) | ||||
| ) | ) | ||||
| @classmethod | @classmethod | ||||
| def migrate_db_records(cls, table_name: str, provider_column_name: str) -> None: | |||||
| def migrate_db_records( | |||||
| cls, table_name: str, provider_column_name: str, provider_cls: type[GenericProviderID] | |||||
| ) -> None: | |||||
| click.echo(click.style(f"Migrating [{table_name}] data for plugin", fg="white")) | click.echo(click.style(f"Migrating [{table_name}] data for plugin", fg="white")) | ||||
| processed_count = 0 | processed_count = 0 | ||||
| ) | ) | ||||
| try: | try: | ||||
| updated_value = f"{DEFAULT_PLUGIN_ID}/{provider_name}/{provider_name}" | |||||
| # update jina to langgenius/jina_tool/jina etc. | |||||
| updated_value = provider_cls(provider_name).to_string() | |||||
| batch_updates.append((updated_value, record_id)) | batch_updates.append((updated_value, record_id)) | ||||
| except Exception as e: | except Exception as e: | ||||
| failed_ids.append(record_id) | failed_ids.append(record_id) |