You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

data_migration.py 9.0KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209
  1. import json
  2. import logging
  3. import click
  4. from core.plugin.entities.plugin import GenericProviderID, ModelProviderID, ToolProviderID
  5. from models.engine import db
# Module-level logger; the migration methods below use it to record per-record failures with tracebacks.
logger = logging.getLogger(__name__)
  7. class PluginDataMigration:
  8. @classmethod
  9. def migrate(cls) -> None:
  10. cls.migrate_db_records("providers", "provider_name", ModelProviderID) # large table
  11. cls.migrate_db_records("provider_models", "provider_name", ModelProviderID)
  12. cls.migrate_db_records("provider_orders", "provider_name", ModelProviderID)
  13. cls.migrate_db_records("tenant_default_models", "provider_name", ModelProviderID)
  14. cls.migrate_db_records("tenant_preferred_model_providers", "provider_name", ModelProviderID)
  15. cls.migrate_db_records("provider_model_settings", "provider_name", ModelProviderID)
  16. cls.migrate_db_records("load_balancing_model_configs", "provider_name", ModelProviderID)
  17. cls.migrate_datasets()
  18. cls.migrate_db_records("embeddings", "provider_name", ModelProviderID) # large table
  19. cls.migrate_db_records("dataset_collection_bindings", "provider_name", ModelProviderID)
  20. cls.migrate_db_records("tool_builtin_providers", "provider", ToolProviderID)
  21. @classmethod
  22. def migrate_datasets(cls) -> None:
  23. table_name = "datasets"
  24. provider_column_name = "embedding_model_provider"
  25. click.echo(click.style(f"Migrating [{table_name}] data for plugin", fg="white"))
  26. processed_count = 0
  27. failed_ids = []
  28. while True:
  29. sql = f"""select id, {provider_column_name} as provider_name, retrieval_model from {table_name}
  30. where {provider_column_name} not like '%/%' and {provider_column_name} is not null and {provider_column_name} != ''
  31. limit 1000"""
  32. with db.engine.begin() as conn:
  33. rs = conn.execute(db.text(sql))
  34. current_iter_count = 0
  35. for i in rs:
  36. record_id = str(i.id)
  37. provider_name = str(i.provider_name)
  38. retrieval_model = i.retrieval_model
  39. print(type(retrieval_model))
  40. if record_id in failed_ids:
  41. continue
  42. retrieval_model_changed = False
  43. if retrieval_model:
  44. if (
  45. "reranking_model" in retrieval_model
  46. and "reranking_provider_name" in retrieval_model["reranking_model"]
  47. and retrieval_model["reranking_model"]["reranking_provider_name"]
  48. and "/" not in retrieval_model["reranking_model"]["reranking_provider_name"]
  49. ):
  50. click.echo(
  51. click.style(
  52. f"[{processed_count}] Migrating {table_name} {record_id} "
  53. f"(reranking_provider_name: "
  54. f"{retrieval_model['reranking_model']['reranking_provider_name']})",
  55. fg="white",
  56. )
  57. )
  58. # update google to langgenius/gemini/google etc.
  59. retrieval_model["reranking_model"]["reranking_provider_name"] = ModelProviderID(
  60. retrieval_model["reranking_model"]["reranking_provider_name"]
  61. ).to_string()
  62. retrieval_model_changed = True
  63. click.echo(
  64. click.style(
  65. f"[{processed_count}] Migrating [{table_name}] {record_id} ({provider_name})",
  66. fg="white",
  67. )
  68. )
  69. try:
  70. # update provider name append with "langgenius/{provider_name}/{provider_name}"
  71. params = {"record_id": record_id}
  72. update_retrieval_model_sql = ""
  73. if retrieval_model and retrieval_model_changed:
  74. update_retrieval_model_sql = ", retrieval_model = :retrieval_model"
  75. params["retrieval_model"] = json.dumps(retrieval_model)
  76. params["provider_name"] = ModelProviderID(provider_name).to_string()
  77. sql = f"""update {table_name}
  78. set {provider_column_name} =
  79. :provider_name
  80. {update_retrieval_model_sql}
  81. where id = :record_id"""
  82. conn.execute(db.text(sql), params)
  83. click.echo(
  84. click.style(
  85. f"[{processed_count}] Migrated [{table_name}] {record_id} ({provider_name})",
  86. fg="green",
  87. )
  88. )
  89. except Exception:
  90. failed_ids.append(record_id)
  91. click.echo(
  92. click.style(
  93. f"[{processed_count}] Failed to migrate [{table_name}] {record_id} ({provider_name})",
  94. fg="red",
  95. )
  96. )
  97. logger.exception(
  98. f"[{processed_count}] Failed to migrate [{table_name}] {record_id} ({provider_name})"
  99. )
  100. continue
  101. current_iter_count += 1
  102. processed_count += 1
  103. if not current_iter_count:
  104. break
  105. click.echo(
  106. click.style(f"Migrate [{table_name}] data for plugin completed, total: {processed_count}", fg="green")
  107. )
  108. @classmethod
  109. def migrate_db_records(
  110. cls, table_name: str, provider_column_name: str, provider_cls: type[GenericProviderID]
  111. ) -> None:
  112. click.echo(click.style(f"Migrating [{table_name}] data for plugin", fg="white"))
  113. processed_count = 0
  114. failed_ids = []
  115. last_id = "00000000-0000-0000-0000-000000000000"
  116. while True:
  117. sql = f"""
  118. SELECT id, {provider_column_name} AS provider_name
  119. FROM {table_name}
  120. WHERE {provider_column_name} NOT LIKE '%/%'
  121. AND {provider_column_name} IS NOT NULL
  122. AND {provider_column_name} != ''
  123. AND id > :last_id
  124. ORDER BY id ASC
  125. LIMIT 5000
  126. """
  127. params = {"last_id": last_id or ""}
  128. with db.engine.begin() as conn:
  129. rs = conn.execute(db.text(sql), params)
  130. current_iter_count = 0
  131. batch_updates = []
  132. for i in rs:
  133. current_iter_count += 1
  134. processed_count += 1
  135. record_id = str(i.id)
  136. last_id = record_id
  137. provider_name = str(i.provider_name)
  138. if record_id in failed_ids:
  139. continue
  140. click.echo(
  141. click.style(
  142. f"[{processed_count}] Migrating [{table_name}] {record_id} ({provider_name})",
  143. fg="white",
  144. )
  145. )
  146. try:
  147. # update jina to langgenius/jina_tool/jina etc.
  148. updated_value = provider_cls(provider_name).to_string()
  149. batch_updates.append((updated_value, record_id))
  150. except Exception as e:
  151. failed_ids.append(record_id)
  152. click.echo(
  153. click.style(
  154. f"[{processed_count}] Failed to migrate [{table_name}] {record_id} ({provider_name})",
  155. fg="red",
  156. )
  157. )
  158. logger.exception(
  159. f"[{processed_count}] Failed to migrate [{table_name}] {record_id} ({provider_name})"
  160. )
  161. continue
  162. if batch_updates:
  163. update_sql = f"""
  164. UPDATE {table_name}
  165. SET {provider_column_name} = :updated_value
  166. WHERE id = :record_id
  167. """
  168. conn.execute(db.text(update_sql), [{"updated_value": u, "record_id": r} for u, r in batch_updates])
  169. click.echo(
  170. click.style(
  171. f"[{processed_count}] Batch migrated [{len(batch_updates)}] records from [{table_name}]",
  172. fg="green",
  173. )
  174. )
  175. if not current_iter_count:
  176. break
  177. click.echo(
  178. click.style(f"Migrate [{table_name}] data for plugin completed, total: {processed_count}", fg="green")
  179. )