| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208 | 
							- import json
 - import logging
 - 
 - import click
 - import sqlalchemy as sa
 - 
 - from core.plugin.entities.plugin import GenericProviderID, ModelProviderID, ToolProviderID
 - from models.engine import db
 - 
 - logger = logging.getLogger(__name__)
 - 
 - 
 - class PluginDataMigration:
 -     @classmethod
 -     def migrate(cls):
 -         cls.migrate_db_records("providers", "provider_name", ModelProviderID)  # large table
 -         cls.migrate_db_records("provider_models", "provider_name", ModelProviderID)
 -         cls.migrate_db_records("provider_orders", "provider_name", ModelProviderID)
 -         cls.migrate_db_records("tenant_default_models", "provider_name", ModelProviderID)
 -         cls.migrate_db_records("tenant_preferred_model_providers", "provider_name", ModelProviderID)
 -         cls.migrate_db_records("provider_model_settings", "provider_name", ModelProviderID)
 -         cls.migrate_db_records("load_balancing_model_configs", "provider_name", ModelProviderID)
 -         cls.migrate_datasets()
 -         cls.migrate_db_records("embeddings", "provider_name", ModelProviderID)  # large table
 -         cls.migrate_db_records("dataset_collection_bindings", "provider_name", ModelProviderID)
 -         cls.migrate_db_records("tool_builtin_providers", "provider", ToolProviderID)
 - 
 -     @classmethod
 -     def migrate_datasets(cls):
 -         table_name = "datasets"
 -         provider_column_name = "embedding_model_provider"
 - 
 -         click.echo(click.style(f"Migrating [{table_name}] data for plugin", fg="white"))
 - 
 -         processed_count = 0
 -         failed_ids = []
 -         while True:
 -             sql = f"""select id, {provider_column_name} as provider_name, retrieval_model from {table_name}
 - where {provider_column_name} not like '%/%' and {provider_column_name} is not null and {provider_column_name} != ''
 - limit 1000"""
 -             with db.engine.begin() as conn:
 -                 rs = conn.execute(sa.text(sql))
 - 
 -                 current_iter_count = 0
 -                 for i in rs:
 -                     record_id = str(i.id)
 -                     provider_name = str(i.provider_name)
 -                     retrieval_model = i.retrieval_model
 -                     print(type(retrieval_model))
 - 
 -                     if record_id in failed_ids:
 -                         continue
 - 
 -                     retrieval_model_changed = False
 -                     if retrieval_model:
 -                         if (
 -                             "reranking_model" in retrieval_model
 -                             and "reranking_provider_name" in retrieval_model["reranking_model"]
 -                             and retrieval_model["reranking_model"]["reranking_provider_name"]
 -                             and "/" not in retrieval_model["reranking_model"]["reranking_provider_name"]
 -                         ):
 -                             click.echo(
 -                                 click.style(
 -                                     f"[{processed_count}] Migrating {table_name} {record_id} "
 -                                     f"(reranking_provider_name: "
 -                                     f"{retrieval_model['reranking_model']['reranking_provider_name']})",
 -                                     fg="white",
 -                                 )
 -                             )
 -                             # update google to langgenius/gemini/google etc.
 -                             retrieval_model["reranking_model"]["reranking_provider_name"] = ModelProviderID(
 -                                 retrieval_model["reranking_model"]["reranking_provider_name"]
 -                             ).to_string()
 -                             retrieval_model_changed = True
 - 
 -                     click.echo(
 -                         click.style(
 -                             f"[{processed_count}] Migrating [{table_name}] {record_id} ({provider_name})",
 -                             fg="white",
 -                         )
 -                     )
 - 
 -                     try:
 -                         # update provider name append with "langgenius/{provider_name}/{provider_name}"
 -                         params = {"record_id": record_id}
 -                         update_retrieval_model_sql = ""
 -                         if retrieval_model and retrieval_model_changed:
 -                             update_retrieval_model_sql = ", retrieval_model = :retrieval_model"
 -                             params["retrieval_model"] = json.dumps(retrieval_model)
 - 
 -                         params["provider_name"] = ModelProviderID(provider_name).to_string()
 - 
 -                         sql = f"""update {table_name}
 -                         set {provider_column_name} =
 -                         :provider_name
 -                         {update_retrieval_model_sql}
 -                         where id = :record_id"""
 -                         conn.execute(sa.text(sql), params)
 -                         click.echo(
 -                             click.style(
 -                                 f"[{processed_count}] Migrated [{table_name}] {record_id} ({provider_name})",
 -                                 fg="green",
 -                             )
 -                         )
 -                     except Exception:
 -                         failed_ids.append(record_id)
 -                         click.echo(
 -                             click.style(
 -                                 f"[{processed_count}] Failed to migrate [{table_name}] {record_id} ({provider_name})",
 -                                 fg="red",
 -                             )
 -                         )
 -                         logger.exception(
 -                             "[%s] Failed to migrate [%s] %s (%s)", processed_count, table_name, record_id, provider_name
 -                         )
 -                         continue
 - 
 -                     current_iter_count += 1
 -                     processed_count += 1
 - 
 -             if not current_iter_count:
 -                 break
 - 
 -         click.echo(
 -             click.style(f"Migrate [{table_name}] data for plugin completed, total: {processed_count}", fg="green")
 -         )
 - 
 -     @classmethod
 -     def migrate_db_records(cls, table_name: str, provider_column_name: str, provider_cls: type[GenericProviderID]):
 -         click.echo(click.style(f"Migrating [{table_name}] data for plugin", fg="white"))
 - 
 -         processed_count = 0
 -         failed_ids = []
 -         last_id = "00000000-0000-0000-0000-000000000000"
 - 
 -         while True:
 -             sql = f"""
 -                 SELECT id, {provider_column_name} AS provider_name
 -                 FROM {table_name}
 -                 WHERE {provider_column_name} NOT LIKE '%/%'
 -                     AND {provider_column_name} IS NOT NULL
 -                     AND {provider_column_name} != ''
 -                     AND id > :last_id
 -                 ORDER BY id ASC
 -                 LIMIT 5000
 -             """
 -             params = {"last_id": last_id or ""}
 - 
 -             with db.engine.begin() as conn:
 -                 rs = conn.execute(sa.text(sql), params)
 - 
 -                 current_iter_count = 0
 -                 batch_updates = []
 - 
 -                 for i in rs:
 -                     current_iter_count += 1
 -                     processed_count += 1
 -                     record_id = str(i.id)
 -                     last_id = record_id
 -                     provider_name = str(i.provider_name)
 - 
 -                     if record_id in failed_ids:
 -                         continue
 - 
 -                     click.echo(
 -                         click.style(
 -                             f"[{processed_count}] Migrating [{table_name}] {record_id} ({provider_name})",
 -                             fg="white",
 -                         )
 -                     )
 - 
 -                     try:
 -                         # update jina to langgenius/jina_tool/jina etc.
 -                         updated_value = provider_cls(provider_name).to_string()
 -                         batch_updates.append((updated_value, record_id))
 -                     except Exception:
 -                         failed_ids.append(record_id)
 -                         click.echo(
 -                             click.style(
 -                                 f"[{processed_count}] Failed to migrate [{table_name}] {record_id} ({provider_name})",
 -                                 fg="red",
 -                             )
 -                         )
 -                         logger.exception(
 -                             "[%s] Failed to migrate [%s] %s (%s)", processed_count, table_name, record_id, provider_name
 -                         )
 -                         continue
 - 
 -                 if batch_updates:
 -                     update_sql = f"""
 -                         UPDATE {table_name}
 -                         SET {provider_column_name} = :updated_value
 -                         WHERE id = :record_id
 -                     """
 -                     conn.execute(sa.text(update_sql), [{"updated_value": u, "record_id": r} for u, r in batch_updates])
 -                     click.echo(
 -                         click.style(
 -                             f"[{processed_count}] Batch migrated [{len(batch_updates)}] records from [{table_name}]",
 -                             fg="green",
 -                         )
 -                     )
 - 
 -             if not current_iter_count:
 -                 break
 - 
 -         click.echo(
 -             click.style(f"Migrate [{table_name}] data for plugin completed, total: {processed_count}", fg="green")
 -         )
 
 
  |