You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

data_migration.py 9.1KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. import json
  2. import logging
  3. import click
  4. import sqlalchemy as sa
  5. from core.plugin.entities.plugin import GenericProviderID, ModelProviderID, ToolProviderID
  6. from models.engine import db
  7. logger = logging.getLogger(__name__)
  8. class PluginDataMigration:
  9. @classmethod
  10. def migrate(cls) -> None:
  11. cls.migrate_db_records("providers", "provider_name", ModelProviderID) # large table
  12. cls.migrate_db_records("provider_models", "provider_name", ModelProviderID)
  13. cls.migrate_db_records("provider_orders", "provider_name", ModelProviderID)
  14. cls.migrate_db_records("tenant_default_models", "provider_name", ModelProviderID)
  15. cls.migrate_db_records("tenant_preferred_model_providers", "provider_name", ModelProviderID)
  16. cls.migrate_db_records("provider_model_settings", "provider_name", ModelProviderID)
  17. cls.migrate_db_records("load_balancing_model_configs", "provider_name", ModelProviderID)
  18. cls.migrate_datasets()
  19. cls.migrate_db_records("embeddings", "provider_name", ModelProviderID) # large table
  20. cls.migrate_db_records("dataset_collection_bindings", "provider_name", ModelProviderID)
  21. cls.migrate_db_records("tool_builtin_providers", "provider", ToolProviderID)
  22. @classmethod
  23. def migrate_datasets(cls) -> None:
  24. table_name = "datasets"
  25. provider_column_name = "embedding_model_provider"
  26. click.echo(click.style(f"Migrating [{table_name}] data for plugin", fg="white"))
  27. processed_count = 0
  28. failed_ids = []
  29. while True:
  30. sql = f"""select id, {provider_column_name} as provider_name, retrieval_model from {table_name}
  31. where {provider_column_name} not like '%/%' and {provider_column_name} is not null and {provider_column_name} != ''
  32. limit 1000"""
  33. with db.engine.begin() as conn:
  34. rs = conn.execute(sa.text(sql))
  35. current_iter_count = 0
  36. for i in rs:
  37. record_id = str(i.id)
  38. provider_name = str(i.provider_name)
  39. retrieval_model = i.retrieval_model
  40. print(type(retrieval_model))
  41. if record_id in failed_ids:
  42. continue
  43. retrieval_model_changed = False
  44. if retrieval_model:
  45. if (
  46. "reranking_model" in retrieval_model
  47. and "reranking_provider_name" in retrieval_model["reranking_model"]
  48. and retrieval_model["reranking_model"]["reranking_provider_name"]
  49. and "/" not in retrieval_model["reranking_model"]["reranking_provider_name"]
  50. ):
  51. click.echo(
  52. click.style(
  53. f"[{processed_count}] Migrating {table_name} {record_id} "
  54. f"(reranking_provider_name: "
  55. f"{retrieval_model['reranking_model']['reranking_provider_name']})",
  56. fg="white",
  57. )
  58. )
  59. # update google to langgenius/gemini/google etc.
  60. retrieval_model["reranking_model"]["reranking_provider_name"] = ModelProviderID(
  61. retrieval_model["reranking_model"]["reranking_provider_name"]
  62. ).to_string()
  63. retrieval_model_changed = True
  64. click.echo(
  65. click.style(
  66. f"[{processed_count}] Migrating [{table_name}] {record_id} ({provider_name})",
  67. fg="white",
  68. )
  69. )
  70. try:
  71. # update provider name append with "langgenius/{provider_name}/{provider_name}"
  72. params = {"record_id": record_id}
  73. update_retrieval_model_sql = ""
  74. if retrieval_model and retrieval_model_changed:
  75. update_retrieval_model_sql = ", retrieval_model = :retrieval_model"
  76. params["retrieval_model"] = json.dumps(retrieval_model)
  77. params["provider_name"] = ModelProviderID(provider_name).to_string()
  78. sql = f"""update {table_name}
  79. set {provider_column_name} =
  80. :provider_name
  81. {update_retrieval_model_sql}
  82. where id = :record_id"""
  83. conn.execute(sa.text(sql), params)
  84. click.echo(
  85. click.style(
  86. f"[{processed_count}] Migrated [{table_name}] {record_id} ({provider_name})",
  87. fg="green",
  88. )
  89. )
  90. except Exception:
  91. failed_ids.append(record_id)
  92. click.echo(
  93. click.style(
  94. f"[{processed_count}] Failed to migrate [{table_name}] {record_id} ({provider_name})",
  95. fg="red",
  96. )
  97. )
  98. logger.exception(
  99. "[%s] Failed to migrate [%s] %s (%s)", processed_count, table_name, record_id, provider_name
  100. )
  101. continue
  102. current_iter_count += 1
  103. processed_count += 1
  104. if not current_iter_count:
  105. break
  106. click.echo(
  107. click.style(f"Migrate [{table_name}] data for plugin completed, total: {processed_count}", fg="green")
  108. )
  109. @classmethod
  110. def migrate_db_records(
  111. cls, table_name: str, provider_column_name: str, provider_cls: type[GenericProviderID]
  112. ) -> None:
  113. click.echo(click.style(f"Migrating [{table_name}] data for plugin", fg="white"))
  114. processed_count = 0
  115. failed_ids = []
  116. last_id = "00000000-0000-0000-0000-000000000000"
  117. while True:
  118. sql = f"""
  119. SELECT id, {provider_column_name} AS provider_name
  120. FROM {table_name}
  121. WHERE {provider_column_name} NOT LIKE '%/%'
  122. AND {provider_column_name} IS NOT NULL
  123. AND {provider_column_name} != ''
  124. AND id > :last_id
  125. ORDER BY id ASC
  126. LIMIT 5000
  127. """
  128. params = {"last_id": last_id or ""}
  129. with db.engine.begin() as conn:
  130. rs = conn.execute(sa.text(sql), params)
  131. current_iter_count = 0
  132. batch_updates = []
  133. for i in rs:
  134. current_iter_count += 1
  135. processed_count += 1
  136. record_id = str(i.id)
  137. last_id = record_id
  138. provider_name = str(i.provider_name)
  139. if record_id in failed_ids:
  140. continue
  141. click.echo(
  142. click.style(
  143. f"[{processed_count}] Migrating [{table_name}] {record_id} ({provider_name})",
  144. fg="white",
  145. )
  146. )
  147. try:
  148. # update jina to langgenius/jina_tool/jina etc.
  149. updated_value = provider_cls(provider_name).to_string()
  150. batch_updates.append((updated_value, record_id))
  151. except Exception as e:
  152. failed_ids.append(record_id)
  153. click.echo(
  154. click.style(
  155. f"[{processed_count}] Failed to migrate [{table_name}] {record_id} ({provider_name})",
  156. fg="red",
  157. )
  158. )
  159. logger.exception(
  160. "[%s] Failed to migrate [%s] %s (%s)", processed_count, table_name, record_id, provider_name
  161. )
  162. continue
  163. if batch_updates:
  164. update_sql = f"""
  165. UPDATE {table_name}
  166. SET {provider_column_name} = :updated_value
  167. WHERE id = :record_id
  168. """
  169. conn.execute(sa.text(update_sql), [{"updated_value": u, "record_id": r} for u, r in batch_updates])
  170. click.echo(
  171. click.style(
  172. f"[{processed_count}] Batch migrated [{len(batch_updates)}] records from [{table_name}]",
  173. fg="green",
  174. )
  175. )
  176. if not current_iter_count:
  177. break
  178. click.echo(
  179. click.style(f"Migrate [{table_name}] data for plugin completed, total: {processed_count}", fg="green")
  180. )