You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

plugin_migration.py 23KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602
  1. import datetime
  2. import json
  3. import logging
  4. import time
  5. from collections.abc import Mapping, Sequence
  6. from concurrent.futures import ThreadPoolExecutor
  7. from pathlib import Path
  8. from typing import Any, Optional
  9. from uuid import uuid4
  10. import click
  11. import sqlalchemy as sa
  12. import tqdm
  13. from flask import Flask, current_app
  14. from sqlalchemy.orm import Session
  15. from core.agent.entities import AgentToolEntity
  16. from core.helper import marketplace
  17. from core.plugin.entities.plugin import PluginInstallationSource
  18. from core.plugin.entities.plugin_daemon import PluginInstallTaskStatus
  19. from core.plugin.impl.plugin import PluginInstaller
  20. from core.tools.entities.tool_entities import ToolProviderType
  21. from extensions.ext_database import db
  22. from models.account import Tenant
  23. from models.model import App, AppMode, AppModelConfig
  24. from models.provider_ids import ModelProviderID, ToolProviderID
  25. from models.tools import BuiltinToolProvider
  26. from models.workflow import Workflow
logger = logging.getLogger(__name__)
# Built-in tool providers that did not move to the plugin marketplace; they are
# skipped when collecting plugin ids for migration.
excluded_providers = ["time", "audio", "code", "webscraper"]
  29. class PluginMigration:
  30. @classmethod
  31. def extract_plugins(cls, filepath: str, workers: int) -> None:
  32. """
  33. Migrate plugin.
  34. """
  35. from threading import Lock
  36. click.echo(click.style("Migrating models/tools to new plugin Mechanism", fg="white"))
  37. ended_at = datetime.datetime.now()
  38. started_at = datetime.datetime(2023, 4, 3, 8, 59, 24)
  39. current_time = started_at
  40. with Session(db.engine) as session:
  41. total_tenant_count = session.query(Tenant.id).count()
  42. click.echo(click.style(f"Total tenant count: {total_tenant_count}", fg="white"))
  43. handled_tenant_count = 0
  44. file_lock = Lock()
  45. counter_lock = Lock()
  46. thread_pool = ThreadPoolExecutor(max_workers=workers)
  47. def process_tenant(flask_app: Flask, tenant_id: str) -> None:
  48. with flask_app.app_context():
  49. nonlocal handled_tenant_count
  50. try:
  51. plugins = cls.extract_installed_plugin_ids(tenant_id)
  52. # Use lock when writing to file
  53. with file_lock:
  54. with open(filepath, "a") as f:
  55. f.write(json.dumps({"tenant_id": tenant_id, "plugins": plugins}) + "\n")
  56. # Use lock when updating counter
  57. with counter_lock:
  58. nonlocal handled_tenant_count
  59. handled_tenant_count += 1
  60. click.echo(
  61. click.style(
  62. f"[{datetime.datetime.now()}] "
  63. f"Processed {handled_tenant_count} tenants "
  64. f"({(handled_tenant_count / total_tenant_count) * 100:.1f}%), "
  65. f"{handled_tenant_count}/{total_tenant_count}",
  66. fg="green",
  67. )
  68. )
  69. except Exception:
  70. logger.exception("Failed to process tenant %s", tenant_id)
  71. futures = []
  72. while current_time < ended_at:
  73. click.echo(click.style(f"Current time: {current_time}, Started at: {datetime.datetime.now()}", fg="white"))
  74. # Initial interval of 1 day, will be dynamically adjusted based on tenant count
  75. interval = datetime.timedelta(days=1)
  76. # Process tenants in this batch
  77. with Session(db.engine) as session:
  78. # Calculate tenant count in next batch with current interval
  79. # Try different intervals until we find one with a reasonable tenant count
  80. test_intervals = [
  81. datetime.timedelta(days=1),
  82. datetime.timedelta(hours=12),
  83. datetime.timedelta(hours=6),
  84. datetime.timedelta(hours=3),
  85. datetime.timedelta(hours=1),
  86. ]
  87. for test_interval in test_intervals:
  88. tenant_count = (
  89. session.query(Tenant.id)
  90. .where(Tenant.created_at.between(current_time, current_time + test_interval))
  91. .count()
  92. )
  93. if tenant_count <= 100:
  94. interval = test_interval
  95. break
  96. else:
  97. # If all intervals have too many tenants, use minimum interval
  98. interval = datetime.timedelta(hours=1)
  99. # Adjust interval to target ~100 tenants per batch
  100. if tenant_count > 0:
  101. # Scale interval based on ratio to target count
  102. interval = min(
  103. datetime.timedelta(days=1), # Max 1 day
  104. max(
  105. datetime.timedelta(hours=1), # Min 1 hour
  106. interval * (100 / tenant_count), # Scale to target 100
  107. ),
  108. )
  109. batch_end = min(current_time + interval, ended_at)
  110. rs = (
  111. session.query(Tenant.id)
  112. .where(Tenant.created_at.between(current_time, batch_end))
  113. .order_by(Tenant.created_at)
  114. )
  115. tenants = []
  116. for row in rs:
  117. tenant_id = str(row.id)
  118. try:
  119. tenants.append(tenant_id)
  120. except Exception:
  121. logger.exception("Failed to process tenant %s", tenant_id)
  122. continue
  123. futures.append(
  124. thread_pool.submit(
  125. process_tenant,
  126. current_app._get_current_object(), # type: ignore[attr-defined]
  127. tenant_id,
  128. )
  129. )
  130. current_time = batch_end
  131. # wait for all threads to finish
  132. for future in futures:
  133. future.result()
  134. @classmethod
  135. def extract_installed_plugin_ids(cls, tenant_id: str) -> Sequence[str]:
  136. """
  137. Extract installed plugin ids.
  138. """
  139. tools = cls.extract_tool_tables(tenant_id)
  140. models = cls.extract_model_tables(tenant_id)
  141. workflows = cls.extract_workflow_tables(tenant_id)
  142. apps = cls.extract_app_tables(tenant_id)
  143. return list({*tools, *models, *workflows, *apps})
  144. @classmethod
  145. def extract_model_tables(cls, tenant_id: str) -> Sequence[str]:
  146. """
  147. Extract model tables.
  148. """
  149. models: list[str] = []
  150. table_pairs = [
  151. ("providers", "provider_name"),
  152. ("provider_models", "provider_name"),
  153. ("provider_orders", "provider_name"),
  154. ("tenant_default_models", "provider_name"),
  155. ("tenant_preferred_model_providers", "provider_name"),
  156. ("provider_model_settings", "provider_name"),
  157. ("load_balancing_model_configs", "provider_name"),
  158. ]
  159. for table, column in table_pairs:
  160. models.extend(cls.extract_model_table(tenant_id, table, column))
  161. # duplicate models
  162. models = list(set(models))
  163. return models
  164. @classmethod
  165. def extract_model_table(cls, tenant_id: str, table: str, column: str) -> Sequence[str]:
  166. """
  167. Extract model table.
  168. """
  169. with Session(db.engine) as session:
  170. rs = session.execute(
  171. sa.text(f"SELECT DISTINCT {column} FROM {table} WHERE tenant_id = :tenant_id"), {"tenant_id": tenant_id}
  172. )
  173. result = []
  174. for row in rs:
  175. provider_name = str(row[0])
  176. result.append(ModelProviderID(provider_name).plugin_id)
  177. return result
  178. @classmethod
  179. def extract_tool_tables(cls, tenant_id: str) -> Sequence[str]:
  180. """
  181. Extract tool tables.
  182. """
  183. with Session(db.engine) as session:
  184. rs = session.query(BuiltinToolProvider).where(BuiltinToolProvider.tenant_id == tenant_id).all()
  185. result = []
  186. for row in rs:
  187. result.append(ToolProviderID(row.provider).plugin_id)
  188. return result
  189. @classmethod
  190. def extract_workflow_tables(cls, tenant_id: str) -> Sequence[str]:
  191. """
  192. Extract workflow tables, only ToolNode is required.
  193. """
  194. with Session(db.engine) as session:
  195. rs = session.query(Workflow).where(Workflow.tenant_id == tenant_id).all()
  196. result = []
  197. for row in rs:
  198. graph = row.graph_dict
  199. # get nodes
  200. nodes = graph.get("nodes", [])
  201. for node in nodes:
  202. data = node.get("data", {})
  203. if data.get("type") == "tool":
  204. provider_name = data.get("provider_name")
  205. provider_type = data.get("provider_type")
  206. if provider_name not in excluded_providers and provider_type == ToolProviderType.BUILT_IN.value:
  207. result.append(ToolProviderID(provider_name).plugin_id)
  208. return result
  209. @classmethod
  210. def extract_app_tables(cls, tenant_id: str) -> Sequence[str]:
  211. """
  212. Extract app tables.
  213. """
  214. with Session(db.engine) as session:
  215. apps = session.query(App).where(App.tenant_id == tenant_id).all()
  216. if not apps:
  217. return []
  218. agent_app_model_config_ids = [
  219. app.app_model_config_id for app in apps if app.is_agent or app.mode == AppMode.AGENT_CHAT.value
  220. ]
  221. rs = session.query(AppModelConfig).where(AppModelConfig.id.in_(agent_app_model_config_ids)).all()
  222. result = []
  223. for row in rs:
  224. agent_config = row.agent_mode_dict
  225. if "tools" in agent_config and isinstance(agent_config["tools"], list):
  226. for tool in agent_config["tools"]:
  227. if isinstance(tool, dict):
  228. try:
  229. tool_entity = AgentToolEntity(**tool)
  230. if (
  231. tool_entity.provider_type == ToolProviderType.BUILT_IN.value
  232. and tool_entity.provider_id not in excluded_providers
  233. ):
  234. result.append(ToolProviderID(tool_entity.provider_id).plugin_id)
  235. except Exception:
  236. logger.exception("Failed to process tool %s", tool)
  237. continue
  238. return result
  239. @classmethod
  240. def _fetch_plugin_unique_identifier(cls, plugin_id: str) -> Optional[str]:
  241. """
  242. Fetch plugin unique identifier using plugin id.
  243. """
  244. plugin_manifest = marketplace.batch_fetch_plugin_manifests([plugin_id])
  245. if not plugin_manifest:
  246. return None
  247. return plugin_manifest[0].latest_package_identifier
  248. @classmethod
  249. def extract_unique_plugins_to_file(cls, extracted_plugins: str, output_file: str) -> None:
  250. """
  251. Extract unique plugins.
  252. """
  253. Path(output_file).write_text(json.dumps(cls.extract_unique_plugins(extracted_plugins)))
  254. @classmethod
  255. def extract_unique_plugins(cls, extracted_plugins: str) -> Mapping[str, Any]:
  256. plugins: dict[str, str] = {}
  257. plugin_ids = []
  258. plugin_not_exist = []
  259. logger.info("Extracting unique plugins from %s", extracted_plugins)
  260. with open(extracted_plugins) as f:
  261. for line in f:
  262. data = json.loads(line)
  263. new_plugin_ids = data.get("plugins", [])
  264. for plugin_id in new_plugin_ids:
  265. if plugin_id not in plugin_ids:
  266. plugin_ids.append(plugin_id)
  267. def fetch_plugin(plugin_id):
  268. try:
  269. unique_identifier = cls._fetch_plugin_unique_identifier(plugin_id)
  270. if unique_identifier:
  271. plugins[plugin_id] = unique_identifier
  272. else:
  273. plugin_not_exist.append(plugin_id)
  274. except Exception:
  275. logger.exception("Failed to fetch plugin unique identifier for %s", plugin_id)
  276. plugin_not_exist.append(plugin_id)
  277. with ThreadPoolExecutor(max_workers=10) as executor:
  278. list(tqdm.tqdm(executor.map(fetch_plugin, plugin_ids), total=len(plugin_ids)))
  279. return {"plugins": plugins, "plugin_not_exist": plugin_not_exist}
  280. @classmethod
  281. def install_plugins(cls, extracted_plugins: str, output_file: str, workers: int = 100) -> None:
  282. """
  283. Install plugins.
  284. """
  285. manager = PluginInstaller()
  286. plugins = cls.extract_unique_plugins(extracted_plugins)
  287. not_installed = []
  288. plugin_install_failed = []
  289. # use a fake tenant id to install all the plugins
  290. fake_tenant_id = uuid4().hex
  291. logger.info("Installing %s plugin instances for fake tenant %s", len(plugins["plugins"]), fake_tenant_id)
  292. thread_pool = ThreadPoolExecutor(max_workers=workers)
  293. response = cls.handle_plugin_instance_install(fake_tenant_id, plugins["plugins"])
  294. if response.get("failed"):
  295. plugin_install_failed.extend(response.get("failed", []))
  296. def install(tenant_id: str, plugin_ids: list[str]) -> None:
  297. logger.info("Installing %s plugins for tenant %s", len(plugin_ids), tenant_id)
  298. # fetch plugin already installed
  299. installed_plugins = manager.list_plugins(tenant_id)
  300. installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
  301. # at most 64 plugins one batch
  302. for i in range(0, len(plugin_ids), 64):
  303. batch_plugin_ids = plugin_ids[i : i + 64]
  304. batch_plugin_identifiers = [
  305. plugins["plugins"][plugin_id]
  306. for plugin_id in batch_plugin_ids
  307. if plugin_id not in installed_plugins_ids and plugin_id in plugins["plugins"]
  308. ]
  309. manager.install_from_identifiers(
  310. tenant_id,
  311. batch_plugin_identifiers,
  312. PluginInstallationSource.Marketplace,
  313. metas=[
  314. {
  315. "plugin_unique_identifier": identifier,
  316. }
  317. for identifier in batch_plugin_identifiers
  318. ],
  319. )
  320. with open(extracted_plugins) as f:
  321. """
  322. Read line by line, and install plugins for each tenant.
  323. """
  324. for line in f:
  325. data = json.loads(line)
  326. tenant_id = data.get("tenant_id")
  327. plugin_ids = data.get("plugins", [])
  328. current_not_installed = {
  329. "tenant_id": tenant_id,
  330. "plugin_not_exist": [],
  331. }
  332. # get plugin unique identifier
  333. for plugin_id in plugin_ids:
  334. unique_identifier = plugins.get(plugin_id)
  335. if unique_identifier:
  336. current_not_installed["plugin_not_exist"].append(plugin_id)
  337. if current_not_installed["plugin_not_exist"]:
  338. not_installed.append(current_not_installed)
  339. thread_pool.submit(install, tenant_id, plugin_ids)
  340. thread_pool.shutdown(wait=True)
  341. logger.info("Uninstall plugins")
  342. # get installation
  343. try:
  344. installation = manager.list_plugins(fake_tenant_id)
  345. while installation:
  346. for plugin in installation:
  347. manager.uninstall(fake_tenant_id, plugin.installation_id)
  348. installation = manager.list_plugins(fake_tenant_id)
  349. except Exception:
  350. logger.exception("Failed to get installation for tenant %s", fake_tenant_id)
  351. Path(output_file).write_text(
  352. json.dumps(
  353. {
  354. "not_installed": not_installed,
  355. "plugin_install_failed": plugin_install_failed,
  356. }
  357. )
  358. )
  359. @classmethod
  360. def install_rag_pipeline_plugins(cls, extracted_plugins: str, output_file: str, workers: int = 100) -> None:
  361. """
  362. Install rag pipeline plugins.
  363. """
  364. manager = PluginInstaller()
  365. plugins = cls.extract_unique_plugins(extracted_plugins)
  366. plugin_install_failed = []
  367. # use a fake tenant id to install all the plugins
  368. fake_tenant_id = uuid4().hex
  369. logger.info("Installing %s plugin instances for fake tenant %s", len(plugins["plugins"]), fake_tenant_id)
  370. thread_pool = ThreadPoolExecutor(max_workers=workers)
  371. response = cls.handle_plugin_instance_install(fake_tenant_id, plugins["plugins"])
  372. if response.get("failed"):
  373. plugin_install_failed.extend(response.get("failed", []))
  374. def install(
  375. tenant_id: str, plugin_ids: dict[str, str], total_success_tenant: int, total_failed_tenant: int
  376. ) -> None:
  377. logger.info("Installing %s plugins for tenant %s", len(plugin_ids), tenant_id)
  378. try:
  379. # fetch plugin already installed
  380. installed_plugins = manager.list_plugins(tenant_id)
  381. installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
  382. # at most 64 plugins one batch
  383. for i in range(0, len(plugin_ids), 64):
  384. batch_plugin_ids = list(plugin_ids.keys())[i : i + 64]
  385. batch_plugin_identifiers = [
  386. plugin_ids[plugin_id]
  387. for plugin_id in batch_plugin_ids
  388. if plugin_id not in installed_plugins_ids and plugin_id in plugin_ids
  389. ]
  390. manager.install_from_identifiers(
  391. tenant_id,
  392. batch_plugin_identifiers,
  393. PluginInstallationSource.Marketplace,
  394. metas=[
  395. {
  396. "plugin_unique_identifier": identifier,
  397. }
  398. for identifier in batch_plugin_identifiers
  399. ],
  400. )
  401. total_success_tenant += 1
  402. except Exception:
  403. logger.exception("Failed to install plugins for tenant %s", tenant_id)
  404. total_failed_tenant += 1
  405. page = 1
  406. total_success_tenant = 0
  407. total_failed_tenant = 0
  408. while True:
  409. # paginate
  410. tenants = db.paginate(db.select(Tenant).order_by(Tenant.created_at.desc()), page=page, per_page=100)
  411. if tenants.items is None or len(tenants.items) == 0:
  412. break
  413. for tenant in tenants:
  414. tenant_id = tenant.id
  415. # get plugin unique identifier
  416. thread_pool.submit(
  417. install,
  418. tenant_id,
  419. plugins.get("plugins", {}),
  420. total_success_tenant,
  421. total_failed_tenant,
  422. )
  423. page += 1
  424. thread_pool.shutdown(wait=True)
  425. # uninstall all the plugins for fake tenant
  426. try:
  427. installation = manager.list_plugins(fake_tenant_id)
  428. while installation:
  429. for plugin in installation:
  430. manager.uninstall(fake_tenant_id, plugin.installation_id)
  431. installation = manager.list_plugins(fake_tenant_id)
  432. except Exception:
  433. logger.exception("Failed to get installation for tenant %s", fake_tenant_id)
  434. Path(output_file).write_text(
  435. json.dumps(
  436. {
  437. "total_success_tenant": total_success_tenant,
  438. "total_failed_tenant": total_failed_tenant,
  439. "plugin_install_failed": plugin_install_failed,
  440. }
  441. )
  442. )
  443. @classmethod
  444. def handle_plugin_instance_install(
  445. cls, tenant_id: str, plugin_identifiers_map: Mapping[str, str]
  446. ) -> Mapping[str, Any]:
  447. """
  448. Install plugins for a tenant.
  449. """
  450. manager = PluginInstaller()
  451. # download all the plugins and upload
  452. thread_pool = ThreadPoolExecutor(max_workers=10)
  453. futures = []
  454. for plugin_id, plugin_identifier in plugin_identifiers_map.items():
  455. def download_and_upload(tenant_id, plugin_id, plugin_identifier):
  456. plugin_package = marketplace.download_plugin_pkg(plugin_identifier)
  457. if not plugin_package:
  458. raise Exception(f"Failed to download plugin {plugin_identifier}")
  459. # upload
  460. manager.upload_pkg(tenant_id, plugin_package, verify_signature=True)
  461. futures.append(thread_pool.submit(download_and_upload, tenant_id, plugin_id, plugin_identifier))
  462. # Wait for all downloads to complete
  463. for future in futures:
  464. future.result() # This will raise any exceptions that occurred
  465. thread_pool.shutdown(wait=True)
  466. success = []
  467. failed = []
  468. reverse_map = {v: k for k, v in plugin_identifiers_map.items()}
  469. # at most 8 plugins one batch
  470. for i in range(0, len(plugin_identifiers_map), 8):
  471. batch_plugin_ids = list(plugin_identifiers_map.keys())[i : i + 8]
  472. batch_plugin_identifiers = [plugin_identifiers_map[plugin_id] for plugin_id in batch_plugin_ids]
  473. try:
  474. response = manager.install_from_identifiers(
  475. tenant_id=tenant_id,
  476. identifiers=batch_plugin_identifiers,
  477. source=PluginInstallationSource.Marketplace,
  478. metas=[
  479. {
  480. "plugin_unique_identifier": identifier,
  481. }
  482. for identifier in batch_plugin_identifiers
  483. ],
  484. )
  485. except Exception:
  486. # add to failed
  487. failed.extend(batch_plugin_identifiers)
  488. continue
  489. if response.all_installed:
  490. success.extend(batch_plugin_identifiers)
  491. continue
  492. task_id = response.task_id
  493. done = False
  494. while not done:
  495. status = manager.fetch_plugin_installation_task(tenant_id, task_id)
  496. if status.status in [PluginInstallTaskStatus.Failed, PluginInstallTaskStatus.Success]:
  497. for plugin in status.plugins:
  498. if plugin.status == PluginInstallTaskStatus.Success:
  499. success.append(reverse_map[plugin.plugin_unique_identifier])
  500. else:
  501. failed.append(reverse_map[plugin.plugin_unique_identifier])
  502. logger.error(
  503. "Failed to install plugin %s, error: %s",
  504. plugin.plugin_unique_identifier,
  505. plugin.message,
  506. )
  507. done = True
  508. else:
  509. time.sleep(1)
  510. return {"success": success, "failed": failed}