| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210 |
- import json
- import logging
-
- import click
- import sqlalchemy as sa
-
- from core.plugin.entities.plugin import GenericProviderID, ModelProviderID, ToolProviderID
- from models.engine import db
-
- logger = logging.getLogger(__name__)
-
-
class PluginDataMigration:
    """Rewrite legacy provider names into their plugin-qualified form.

    Every migrated column stores a provider name. Rows whose value is
    non-empty and contains no ``/`` are selected and replaced with the string
    returned by the matching provider-ID class, i.e.
    ``provider_cls(name).to_string()``.
    """

    # Seed for keyset pagination: each batch selects ``id > last_id``.
    _FIRST_ID = "00000000-0000-0000-0000-000000000000"

    @classmethod
    def migrate(cls) -> None:
        """Run every provider-name migration, one table after another."""
        cls.migrate_db_records("providers", "provider_name", ModelProviderID)  # large table
        cls.migrate_db_records("provider_models", "provider_name", ModelProviderID)
        cls.migrate_db_records("provider_orders", "provider_name", ModelProviderID)
        cls.migrate_db_records("tenant_default_models", "provider_name", ModelProviderID)
        cls.migrate_db_records("tenant_preferred_model_providers", "provider_name", ModelProviderID)
        cls.migrate_db_records("provider_model_settings", "provider_name", ModelProviderID)
        cls.migrate_db_records("load_balancing_model_configs", "provider_name", ModelProviderID)
        cls.migrate_datasets()
        cls.migrate_db_records("embeddings", "provider_name", ModelProviderID)  # large table
        cls.migrate_db_records("dataset_collection_bindings", "provider_name", ModelProviderID)
        cls.migrate_db_records("tool_builtin_providers", "provider", ToolProviderID)

    @staticmethod
    def _legacy_reranking_provider(retrieval_model: object) -> str:
        """Return the reranking provider name that still needs migrating.

        Args:
            retrieval_model: Decoded ``retrieval_model`` value of a dataset row.

        Returns:
            The value of ``retrieval_model["reranking_model"]["reranking_provider_name"]``
            when it is a non-empty string without ``/``; otherwise ``""``
            (missing keys, non-dict values, empty or already migrated names).
        """
        if not isinstance(retrieval_model, dict):
            return ""
        reranking_model = retrieval_model.get("reranking_model")
        if not isinstance(reranking_model, dict):
            return ""
        name = reranking_model.get("reranking_provider_name")
        if isinstance(name, str) and name and "/" not in name:
            return name
        return ""

    @classmethod
    def migrate_datasets(cls) -> None:
        """Migrate ``datasets.embedding_model_provider`` and the reranking provider.

        Rows are read in batches of 1000 using keyset pagination on ``id``, one
        transaction per batch. For each row the embedding provider column is
        rewritten via ``ModelProviderID``; when the row's ``retrieval_model``
        holds a legacy ``reranking_provider_name`` it is rewritten as well and
        the JSON is saved back in the same UPDATE.

        A row that cannot be migrated is reported and logged, then skipped; it
        does not stop the migration.
        """
        table_name = "datasets"
        provider_column_name = "embedding_model_provider"

        click.echo(click.style(f"Migrating [{table_name}] data for plugin", fg="white"))

        # Table and column names are the constants above, not external input,
        # so interpolating them into the statement text is safe.
        select_stmt = sa.text(
            f"""
            SELECT id, {provider_column_name} AS provider_name, retrieval_model
            FROM {table_name}
            WHERE {provider_column_name} NOT LIKE '%/%'
              AND {provider_column_name} IS NOT NULL
              AND {provider_column_name} != ''
              AND id > :last_id
            ORDER BY id ASC
            LIMIT 1000
            """
        )

        processed_count = 0
        last_id = cls._FIRST_ID

        while True:
            with db.engine.begin() as conn:
                # Materialise the batch before issuing UPDATEs on this connection.
                rows = conn.execute(select_stmt, {"last_id": last_id}).fetchall()

                for row in rows:
                    record_id = str(row.id)
                    # Advance the cursor first so a failing row is never re-read.
                    last_id = record_id
                    provider_name = str(row.provider_name)

                    try:
                        retrieval_model = row.retrieval_model
                        if isinstance(retrieval_model, str):
                            # Some drivers may hand JSON columns back as text.
                            retrieval_model = json.loads(retrieval_model)

                        params = {"record_id": record_id}
                        update_retrieval_model_sql = ""

                        legacy_reranker = cls._legacy_reranking_provider(retrieval_model)
                        if legacy_reranker:
                            click.echo(
                                click.style(
                                    f"[{processed_count}] Migrating {table_name} {record_id} "
                                    f"(reranking_provider_name: "
                                    f"{legacy_reranker})",
                                    fg="white",
                                )
                            )
                            # update google to langgenius/gemini/google etc.
                            retrieval_model["reranking_model"]["reranking_provider_name"] = ModelProviderID(
                                legacy_reranker
                            ).to_string()
                            update_retrieval_model_sql = ", retrieval_model = :retrieval_model"
                            params["retrieval_model"] = json.dumps(retrieval_model)

                        click.echo(
                            click.style(
                                f"[{processed_count}] Migrating [{table_name}] {record_id} ({provider_name})",
                                fg="white",
                            )
                        )

                        # update provider name append with "langgenius/{provider_name}/{provider_name}"
                        params["provider_name"] = ModelProviderID(provider_name).to_string()

                        update_stmt = sa.text(
                            f"""
                            UPDATE {table_name}
                            SET {provider_column_name} = :provider_name
                                {update_retrieval_model_sql}
                            WHERE id = :record_id
                            """
                        )
                        # Savepoint per row: a failed UPDATE is rolled back on its
                        # own instead of leaving the batch transaction unusable
                        # (as happens on backends such as PostgreSQL).
                        with conn.begin_nested():
                            conn.execute(update_stmt, params)
                    except Exception:
                        click.echo(
                            click.style(
                                f"[{processed_count}] Failed to migrate [{table_name}] {record_id} ({provider_name})",
                                fg="red",
                            )
                        )
                        logger.exception(
                            "[%s] Failed to migrate [%s] %s (%s)", processed_count, table_name, record_id, provider_name
                        )
                        continue

                    click.echo(
                        click.style(
                            f"[{processed_count}] Migrated [{table_name}] {record_id} ({provider_name})",
                            fg="green",
                        )
                    )
                    processed_count += 1

            # An empty page means every remaining id has been visited.
            if not rows:
                break

        click.echo(
            click.style(f"Migrate [{table_name}] data for plugin completed, total: {processed_count}", fg="green")
        )

    @classmethod
    def migrate_db_records(
        cls, table_name: str, provider_column_name: str, provider_cls: "type[GenericProviderID]"
    ) -> None:
        """Migrate one provider-name column of one table.

        Rows are read in batches of 5000 using keyset pagination on ``id``.
        Each batch runs in its own transaction and is written back with a
        single executemany UPDATE.

        Args:
            table_name: Table to migrate. Interpolated into the SQL text, so it
                must be a trusted identifier.
            provider_column_name: Column holding the provider name. Also
                interpolated into the SQL text.
            provider_cls: Provider-ID class; ``provider_cls(name).to_string()``
                yields the new column value.

        A row whose name cannot be converted is reported and logged, then left
        unchanged. ``processed_count`` counts every row visited, including
        those that failed.
        """
        click.echo(click.style(f"Migrating [{table_name}] data for plugin", fg="white"))

        # Both statements are loop-invariant, so build them once.
        select_stmt = sa.text(
            f"""
            SELECT id, {provider_column_name} AS provider_name
            FROM {table_name}
            WHERE {provider_column_name} NOT LIKE '%/%'
              AND {provider_column_name} IS NOT NULL
              AND {provider_column_name} != ''
              AND id > :last_id
            ORDER BY id ASC
            LIMIT 5000
            """
        )
        update_stmt = sa.text(
            f"""
            UPDATE {table_name}
            SET {provider_column_name} = :updated_value
            WHERE id = :record_id
            """
        )

        processed_count = 0
        last_id = cls._FIRST_ID

        while True:
            with db.engine.begin() as conn:
                # Materialise the batch before issuing UPDATEs on this connection.
                rows = conn.execute(select_stmt, {"last_id": last_id}).fetchall()
                batch_updates = []

                for row in rows:
                    processed_count += 1
                    record_id = str(row.id)
                    # Ids only move forward, so no row is ever visited twice.
                    last_id = record_id
                    provider_name = str(row.provider_name)

                    click.echo(
                        click.style(
                            f"[{processed_count}] Migrating [{table_name}] {record_id} ({provider_name})",
                            fg="white",
                        )
                    )

                    try:
                        # update jina to langgenius/jina_tool/jina etc.
                        updated_value = provider_cls(provider_name).to_string()
                    except Exception:
                        click.echo(
                            click.style(
                                f"[{processed_count}] Failed to migrate [{table_name}] {record_id} ({provider_name})",
                                fg="red",
                            )
                        )
                        logger.exception(
                            "[%s] Failed to migrate [%s] %s (%s)", processed_count, table_name, record_id, provider_name
                        )
                        continue

                    batch_updates.append({"updated_value": updated_value, "record_id": record_id})

                if batch_updates:
                    # A list of parameter dicts makes this one executemany call.
                    conn.execute(update_stmt, batch_updates)
                    click.echo(
                        click.style(
                            f"[{processed_count}] Batch migrated [{len(batch_updates)}] records from [{table_name}]",
                            fg="green",
                        )
                    )

            # An empty page means every remaining id has been visited.
            if not rows:
                break

        click.echo(
            click.style(f"Migrate [{table_name}] data for plugin completed, total: {processed_count}", fg="green")
        )
|