Du kannst nicht mehr als 25 Themen auswählen. Themen müssen entweder mit einem Buchstaben oder einer Ziffer beginnen. Sie können Bindestriche („-“) enthalten und bis zu 35 Zeichen lang sein.

data_migration.py 9.2KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
  1. import json
  2. import logging
  3. import click
  4. import sqlalchemy as sa
  5. from extensions.ext_database import db
  6. from models.provider_ids import GenericProviderID, ModelProviderID, ToolProviderID
  7. logger = logging.getLogger(__name__)
  8. class PluginDataMigration:
  9. @classmethod
  10. def migrate(cls):
  11. cls.migrate_db_records("providers", "provider_name", ModelProviderID) # large table
  12. cls.migrate_db_records("provider_models", "provider_name", ModelProviderID)
  13. cls.migrate_db_records("provider_orders", "provider_name", ModelProviderID)
  14. cls.migrate_db_records("tenant_default_models", "provider_name", ModelProviderID)
  15. cls.migrate_db_records("tenant_preferred_model_providers", "provider_name", ModelProviderID)
  16. cls.migrate_db_records("provider_model_settings", "provider_name", ModelProviderID)
  17. cls.migrate_db_records("load_balancing_model_configs", "provider_name", ModelProviderID)
  18. cls.migrate_datasets()
  19. cls.migrate_db_records("embeddings", "provider_name", ModelProviderID) # large table
  20. cls.migrate_db_records("dataset_collection_bindings", "provider_name", ModelProviderID)
  21. cls.migrate_db_records("tool_builtin_providers", "provider", ToolProviderID)
  22. @classmethod
  23. def migrate_datasets(cls):
  24. table_name = "datasets"
  25. provider_column_name = "embedding_model_provider"
  26. click.echo(click.style(f"Migrating [{table_name}] data for plugin", fg="white"))
  27. processed_count = 0
  28. failed_ids = []
  29. while True:
  30. sql = f"""select id, {provider_column_name} as provider_name, retrieval_model from {table_name}
  31. where {provider_column_name} not like '%/%' and {provider_column_name} is not null and {provider_column_name} != ''
  32. limit 1000"""
  33. with db.engine.begin() as conn:
  34. rs = conn.execute(sa.text(sql))
  35. current_iter_count = 0
  36. for i in rs:
  37. record_id = str(i.id)
  38. provider_name = str(i.provider_name)
  39. retrieval_model = i.retrieval_model
  40. logger.debug(
  41. "Processing dataset %s with retrieval model of type %s",
  42. record_id,
  43. type(retrieval_model),
  44. )
  45. if record_id in failed_ids:
  46. continue
  47. retrieval_model_changed = False
  48. if retrieval_model:
  49. if (
  50. "reranking_model" in retrieval_model
  51. and "reranking_provider_name" in retrieval_model["reranking_model"]
  52. and retrieval_model["reranking_model"]["reranking_provider_name"]
  53. and "/" not in retrieval_model["reranking_model"]["reranking_provider_name"]
  54. ):
  55. click.echo(
  56. click.style(
  57. f"[{processed_count}] Migrating {table_name} {record_id} "
  58. f"(reranking_provider_name: "
  59. f"{retrieval_model['reranking_model']['reranking_provider_name']})",
  60. fg="white",
  61. )
  62. )
  63. # update google to langgenius/gemini/google etc.
  64. retrieval_model["reranking_model"]["reranking_provider_name"] = ModelProviderID(
  65. retrieval_model["reranking_model"]["reranking_provider_name"]
  66. ).to_string()
  67. retrieval_model_changed = True
  68. click.echo(
  69. click.style(
  70. f"[{processed_count}] Migrating [{table_name}] {record_id} ({provider_name})",
  71. fg="white",
  72. )
  73. )
  74. try:
  75. # update provider name append with "langgenius/{provider_name}/{provider_name}"
  76. params = {"record_id": record_id}
  77. update_retrieval_model_sql = ""
  78. if retrieval_model and retrieval_model_changed:
  79. update_retrieval_model_sql = ", retrieval_model = :retrieval_model"
  80. params["retrieval_model"] = json.dumps(retrieval_model)
  81. params["provider_name"] = ModelProviderID(provider_name).to_string()
  82. sql = f"""update {table_name}
  83. set {provider_column_name} =
  84. :provider_name
  85. {update_retrieval_model_sql}
  86. where id = :record_id"""
  87. conn.execute(sa.text(sql), params)
  88. click.echo(
  89. click.style(
  90. f"[{processed_count}] Migrated [{table_name}] {record_id} ({provider_name})",
  91. fg="green",
  92. )
  93. )
  94. except Exception:
  95. failed_ids.append(record_id)
  96. click.echo(
  97. click.style(
  98. f"[{processed_count}] Failed to migrate [{table_name}] {record_id} ({provider_name})",
  99. fg="red",
  100. )
  101. )
  102. logger.exception(
  103. "[%s] Failed to migrate [%s] %s (%s)", processed_count, table_name, record_id, provider_name
  104. )
  105. continue
  106. current_iter_count += 1
  107. processed_count += 1
  108. if not current_iter_count:
  109. break
  110. click.echo(
  111. click.style(f"Migrate [{table_name}] data for plugin completed, total: {processed_count}", fg="green")
  112. )
  113. @classmethod
  114. def migrate_db_records(cls, table_name: str, provider_column_name: str, provider_cls: type[GenericProviderID]):
  115. click.echo(click.style(f"Migrating [{table_name}] data for plugin", fg="white"))
  116. processed_count = 0
  117. failed_ids = []
  118. last_id = "00000000-0000-0000-0000-000000000000"
  119. while True:
  120. sql = f"""
  121. SELECT id, {provider_column_name} AS provider_name
  122. FROM {table_name}
  123. WHERE {provider_column_name} NOT LIKE '%/%'
  124. AND {provider_column_name} IS NOT NULL
  125. AND {provider_column_name} != ''
  126. AND id > :last_id
  127. ORDER BY id ASC
  128. LIMIT 5000
  129. """
  130. params = {"last_id": last_id or ""}
  131. with db.engine.begin() as conn:
  132. rs = conn.execute(sa.text(sql), params)
  133. current_iter_count = 0
  134. batch_updates = []
  135. for i in rs:
  136. current_iter_count += 1
  137. processed_count += 1
  138. record_id = str(i.id)
  139. last_id = record_id
  140. provider_name = str(i.provider_name)
  141. if record_id in failed_ids:
  142. continue
  143. click.echo(
  144. click.style(
  145. f"[{processed_count}] Migrating [{table_name}] {record_id} ({provider_name})",
  146. fg="white",
  147. )
  148. )
  149. try:
  150. # update jina to langgenius/jina_tool/jina etc.
  151. updated_value = provider_cls(provider_name).to_string()
  152. batch_updates.append((updated_value, record_id))
  153. except Exception:
  154. failed_ids.append(record_id)
  155. click.echo(
  156. click.style(
  157. f"[{processed_count}] Failed to migrate [{table_name}] {record_id} ({provider_name})",
  158. fg="red",
  159. )
  160. )
  161. logger.exception(
  162. "[%s] Failed to migrate [%s] %s (%s)", processed_count, table_name, record_id, provider_name
  163. )
  164. continue
  165. if batch_updates:
  166. update_sql = f"""
  167. UPDATE {table_name}
  168. SET {provider_column_name} = :updated_value
  169. WHERE id = :record_id
  170. """
  171. conn.execute(sa.text(update_sql), [{"updated_value": u, "record_id": r} for u, r in batch_updates])
  172. click.echo(
  173. click.style(
  174. f"[{processed_count}] Batch migrated [{len(batch_updates)}] records from [{table_name}]",
  175. fg="green",
  176. )
  177. )
  178. if not current_iter_count:
  179. break
  180. click.echo(
  181. click.style(f"Migrate [{table_name}] data for plugin completed, total: {processed_count}", fg="green")
  182. )