Browse Source

fix(plugin/migrations) refactor data migration to use specific provider ID classes. (#21187)

tags/1.5.0
Yeuoly 4 months ago
parent
commit
2020a31785
No account linked to committer's email address
1 changed file with 23 additions and 17 deletions
  1. api/services/plugin/data_migration.py (23 additions, 17 deletions)

+ 23
- 17
api/services/plugin/data_migration.py View File



import click import click


from core.entities import DEFAULT_PLUGIN_ID
from core.plugin.entities.plugin import GenericProviderID, ModelProviderID, ToolProviderID
from models.engine import db from models.engine import db


logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class PluginDataMigration: class PluginDataMigration:
@classmethod @classmethod
def migrate(cls) -> None: def migrate(cls) -> None:
cls.migrate_db_records("providers", "provider_name") # large table
cls.migrate_db_records("provider_models", "provider_name")
cls.migrate_db_records("provider_orders", "provider_name")
cls.migrate_db_records("tenant_default_models", "provider_name")
cls.migrate_db_records("tenant_preferred_model_providers", "provider_name")
cls.migrate_db_records("provider_model_settings", "provider_name")
cls.migrate_db_records("load_balancing_model_configs", "provider_name")
cls.migrate_db_records("providers", "provider_name", ModelProviderID) # large table
cls.migrate_db_records("provider_models", "provider_name", ModelProviderID)
cls.migrate_db_records("provider_orders", "provider_name", ModelProviderID)
cls.migrate_db_records("tenant_default_models", "provider_name", ModelProviderID)
cls.migrate_db_records("tenant_preferred_model_providers", "provider_name", ModelProviderID)
cls.migrate_db_records("provider_model_settings", "provider_name", ModelProviderID)
cls.migrate_db_records("load_balancing_model_configs", "provider_name", ModelProviderID)
cls.migrate_datasets() cls.migrate_datasets()
cls.migrate_db_records("embeddings", "provider_name") # large table
cls.migrate_db_records("dataset_collection_bindings", "provider_name")
cls.migrate_db_records("tool_builtin_providers", "provider")
cls.migrate_db_records("embeddings", "provider_name", ModelProviderID) # large table
cls.migrate_db_records("dataset_collection_bindings", "provider_name", ModelProviderID)
cls.migrate_db_records("tool_builtin_providers", "provider_name", ToolProviderID)


@classmethod @classmethod
def migrate_datasets(cls) -> None: def migrate_datasets(cls) -> None:
fg="white", fg="white",
) )
) )
retrieval_model["reranking_model"]["reranking_provider_name"] = (
f"{DEFAULT_PLUGIN_ID}/{retrieval_model['reranking_model']['reranking_provider_name']}/{retrieval_model['reranking_model']['reranking_provider_name']}"
)
# update google to langgenius/gemini/google etc.
retrieval_model["reranking_model"]["reranking_provider_name"] = ModelProviderID(
retrieval_model["reranking_model"]["reranking_provider_name"]
).to_string()
retrieval_model_changed = True retrieval_model_changed = True


click.echo( click.echo(
update_retrieval_model_sql = ", retrieval_model = :retrieval_model" update_retrieval_model_sql = ", retrieval_model = :retrieval_model"
params["retrieval_model"] = json.dumps(retrieval_model) params["retrieval_model"] = json.dumps(retrieval_model)


params["provider_name"] = ModelProviderID(provider_name).to_string()

sql = f"""update {table_name} sql = f"""update {table_name}
set {provider_column_name} = set {provider_column_name} =
concat('{DEFAULT_PLUGIN_ID}/', {provider_column_name}, '/', {provider_column_name})
:provider_name
{update_retrieval_model_sql} {update_retrieval_model_sql}
where id = :record_id""" where id = :record_id"""
conn.execute(db.text(sql), params) conn.execute(db.text(sql), params)
) )


@classmethod @classmethod
def migrate_db_records(cls, table_name: str, provider_column_name: str) -> None:
def migrate_db_records(
cls, table_name: str, provider_column_name: str, provider_cls: type[GenericProviderID]
) -> None:
click.echo(click.style(f"Migrating [{table_name}] data for plugin", fg="white")) click.echo(click.style(f"Migrating [{table_name}] data for plugin", fg="white"))


processed_count = 0 processed_count = 0
) )


try: try:
updated_value = f"{DEFAULT_PLUGIN_ID}/{provider_name}/{provider_name}"
# update jina to langgenius/jina_tool/jina etc.
updated_value = provider_cls(provider_name).to_string()
batch_updates.append((updated_value, record_id)) batch_updates.append((updated_value, record_id))
except Exception as e: except Exception as e:
failed_ids.append(record_id) failed_ids.append(record_id)

Loading…
Cancel
Save