
clear_free_plan_tenant_expired_logs.py 19KB

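"""
Back up and purge expired logs for free (sandbox) plan tenants.

For each tenant, expired messages (and their related tables), conversations,
workflow node executions, workflow runs, and workflow app logs are serialized
to storage as JSON snapshots before the rows are deleted in batches.
"""
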
import datetime
import json
import logging
import time
from concurrent.futures import ThreadPoolExecutor

import click
from flask import Flask, current_app
from sqlalchemy.orm import Session, sessionmaker

from configs import dify_config
from core.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.account import Tenant
from models.model import (
    App,
    AppAnnotationHitHistory,
    Conversation,
    Message,
    MessageAgentThought,
    MessageAnnotation,
    MessageChain,
    MessageFeedback,
    MessageFile,
)
from models.web import SavedMessage
from models.workflow import WorkflowAppLog
from repositories.factory import DifyAPIRepositoryFactory
from services.billing_service import BillingService

logger = logging.getLogger(__name__)


class ClearFreePlanTenantExpiredLogs:
    @classmethod
    def _clear_message_related_tables(cls, session: Session, tenant_id: str, batch_message_ids: list[str]) -> None:
        """
        Clean up message-related tables to avoid data redundancy.

        This method cleans up tables that have foreign key relationships with Message.

        Args:
            session: Database session; the same session used in the process_tenant method
            tenant_id: Tenant ID, used for logging and for naming the backup files
            batch_message_ids: List of message IDs to clean up
        """
        if not batch_message_ids:
            return

        # Clean up each related table
        related_tables = [
            (MessageFeedback, "message_feedbacks"),
            (MessageFile, "message_files"),
            (MessageAnnotation, "message_annotations"),
            (MessageChain, "message_chains"),
            (MessageAgentThought, "message_agent_thoughts"),
            (AppAnnotationHitHistory, "app_annotation_hit_histories"),
            (SavedMessage, "saved_messages"),
        ]
        for model, table_name in related_tables:
            # Query records related to expired messages
            records = (
                session.query(model)
                .where(
                    model.message_id.in_(batch_message_ids),  # type: ignore
                )
                .all()
            )

            if len(records) == 0:
                continue

            # Save records before deletion
            record_ids = [record.id for record in records]
            try:
                record_data = []
                for record in records:
                    try:
                        if hasattr(record, "to_dict"):
                            record_data.append(record.to_dict())
                        else:
                            # if the record has no to_dict method, build the dict from its table columns
                            record_dict = {}
                            for column in record.__table__.columns:
                                record_dict[column.name] = getattr(record, column.name)
                            record_data.append(record_dict)
                    except Exception:
                        logger.exception("Failed to transform %s record: %s", table_name, record.id)
                        continue

                if record_data:
                    storage.save(
                        f"free_plan_tenant_expired_logs/"
                        f"{tenant_id}/{table_name}/{datetime.datetime.now().strftime('%Y-%m-%d')}"
                        f"-{time.time()}.json",
                        json.dumps(
                            jsonable_encoder(record_data),
                        ).encode("utf-8"),
                    )
            except Exception:
                logger.exception("Failed to save %s records", table_name)

            session.query(model).where(
                model.id.in_(record_ids),  # type: ignore
            ).delete(synchronize_session=False)
            click.echo(
                click.style(
                    f"[{datetime.datetime.now()}] Processed {len(record_ids)} "
                    f"{table_name} records for tenant {tenant_id}"
                )
            )
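
    # For reference, backup keys written above (and by process_tenant below)
    # follow this layout; the date and timestamp values are illustrative:
    #   free_plan_tenant_expired_logs/<tenant_id>/<table_name>/2025-01-31-1738300000.123456.json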

    @classmethod
    def process_tenant(cls, flask_app: Flask, tenant_id: str, days: int, batch: int):
        with flask_app.app_context():
            apps = db.session.query(App).where(App.tenant_id == tenant_id).all()
            app_ids = [app.id for app in apps]
            while True:
                with Session(db.engine).no_autoflush as session:
                    messages = (
                        session.query(Message)
                        .where(
                            Message.app_id.in_(app_ids),
                            Message.created_at < datetime.datetime.now() - datetime.timedelta(days=days),
                        )
                        .limit(batch)
                        .all()
                    )

                    if len(messages) == 0:
                        break

                    storage.save(
                        f"free_plan_tenant_expired_logs/"
                        f"{tenant_id}/messages/{datetime.datetime.now().strftime('%Y-%m-%d')}"
                        f"-{time.time()}.json",
                        json.dumps(
                            jsonable_encoder(
                                [message.to_dict() for message in messages],
                            ),
                        ).encode("utf-8"),
                    )

                    message_ids = [message.id for message in messages]
                    # delete messages
                    session.query(Message).where(
                        Message.id.in_(message_ids),
                    ).delete(synchronize_session=False)
                    cls._clear_message_related_tables(session, tenant_id, message_ids)
                    session.commit()
                    click.echo(
                        click.style(
                            f"[{datetime.datetime.now()}] Processed {len(message_ids)} messages for tenant {tenant_id}"
                        )
                    )

            while True:
                with Session(db.engine).no_autoflush as session:
                    conversations = (
                        session.query(Conversation)
                        .where(
                            Conversation.app_id.in_(app_ids),
                            Conversation.updated_at < datetime.datetime.now() - datetime.timedelta(days=days),
                        )
                        .limit(batch)
                        .all()
                    )

                    if len(conversations) == 0:
                        break

                    storage.save(
                        f"free_plan_tenant_expired_logs/"
                        f"{tenant_id}/conversations/{datetime.datetime.now().strftime('%Y-%m-%d')}"
                        f"-{time.time()}.json",
                        json.dumps(
                            jsonable_encoder(
                                [conversation.to_dict() for conversation in conversations],
                            ),
                        ).encode("utf-8"),
                    )

                    conversation_ids = [conversation.id for conversation in conversations]
                    session.query(Conversation).where(
                        Conversation.id.in_(conversation_ids),
                    ).delete(synchronize_session=False)
                    session.commit()
                    click.echo(
                        click.style(
                            f"[{datetime.datetime.now()}] Processed {len(conversation_ids)}"
                            f" conversations for tenant {tenant_id}"
                        )
                    )

            # Process expired workflow node executions with backup
            session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
            node_execution_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(session_maker)
            before_date = datetime.datetime.now() - datetime.timedelta(days=days)
            total_deleted = 0

            while True:
                # Get a batch of expired executions for backup
                workflow_node_executions = node_execution_repo.get_expired_executions_batch(
                    tenant_id=tenant_id,
                    before_date=before_date,
                    batch_size=batch,
                )

                if len(workflow_node_executions) == 0:
                    break

                # Save workflow node executions to storage
                storage.save(
                    f"free_plan_tenant_expired_logs/"
                    f"{tenant_id}/workflow_node_executions/{datetime.datetime.now().strftime('%Y-%m-%d')}"
                    f"-{time.time()}.json",
                    json.dumps(
                        jsonable_encoder(workflow_node_executions),
                    ).encode("utf-8"),
                )

                # Extract IDs for deletion
                workflow_node_execution_ids = [
                    workflow_node_execution.id for workflow_node_execution in workflow_node_executions
                ]

                # Delete the backed-up executions
                deleted_count = node_execution_repo.delete_executions_by_ids(workflow_node_execution_ids)
                total_deleted += deleted_count

                click.echo(
                    click.style(
                        f"[{datetime.datetime.now()}] Processed {len(workflow_node_execution_ids)}"
                        f" workflow node executions for tenant {tenant_id}"
                    )
                )

                # If we got fewer than the batch size, we're done
                if len(workflow_node_executions) < batch:
                    break
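
            # Unlike the message/conversation loops above, which query the ORM
            # session directly, node executions (and workflow runs below) go
            # through repository objects, so the SQL lives behind
            # DifyAPIRepositoryFactory rather than in this command.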

            # Process expired workflow runs with backup
            session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
            workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
            before_date = datetime.datetime.now() - datetime.timedelta(days=days)
            total_deleted = 0

            while True:
                # Get a batch of expired workflow runs for backup
                workflow_runs = workflow_run_repo.get_expired_runs_batch(
                    tenant_id=tenant_id,
                    before_date=before_date,
                    batch_size=batch,
                )

                if len(workflow_runs) == 0:
                    break

                # Save workflow runs to storage
                storage.save(
                    f"free_plan_tenant_expired_logs/"
                    f"{tenant_id}/workflow_runs/{datetime.datetime.now().strftime('%Y-%m-%d')}"
                    f"-{time.time()}.json",
                    json.dumps(
                        jsonable_encoder(
                            [workflow_run.to_dict() for workflow_run in workflow_runs],
                        ),
                    ).encode("utf-8"),
                )

                # Extract IDs for deletion
                workflow_run_ids = [workflow_run.id for workflow_run in workflow_runs]

                # Delete the backed-up workflow runs
                deleted_count = workflow_run_repo.delete_runs_by_ids(workflow_run_ids)
                total_deleted += deleted_count

                click.echo(
                    click.style(
                        f"[{datetime.datetime.now()}] Processed {len(workflow_run_ids)}"
                        f" workflow runs for tenant {tenant_id}"
                    )
                )

                # If we got fewer than the batch size, we're done
                if len(workflow_runs) < batch:
                    break

            while True:
                with Session(db.engine).no_autoflush as session:
                    workflow_app_logs = (
                        session.query(WorkflowAppLog)
                        .where(
                            WorkflowAppLog.tenant_id == tenant_id,
                            WorkflowAppLog.created_at < datetime.datetime.now() - datetime.timedelta(days=days),
                        )
                        .limit(batch)
                        .all()
                    )

                    if len(workflow_app_logs) == 0:
                        break

                    # save workflow app logs
                    storage.save(
                        f"free_plan_tenant_expired_logs/"
                        f"{tenant_id}/workflow_app_logs/{datetime.datetime.now().strftime('%Y-%m-%d')}"
                        f"-{time.time()}.json",
                        json.dumps(
                            jsonable_encoder(
                                [workflow_app_log.to_dict() for workflow_app_log in workflow_app_logs],
                            ),
                        ).encode("utf-8"),
                    )

                    workflow_app_log_ids = [workflow_app_log.id for workflow_app_log in workflow_app_logs]
                    # delete workflow app logs
                    session.query(WorkflowAppLog).where(WorkflowAppLog.id.in_(workflow_app_log_ids)).delete(
                        synchronize_session=False
                    )
                    session.commit()
                    click.echo(
                        click.style(
                            f"[{datetime.datetime.now()}] Processed {len(workflow_app_log_ids)}"
                            f" workflow app logs for tenant {tenant_id}"
                        )
                    )
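
    # Per-tenant cleanup order used above: messages (plus their related tables),
    # conversations, workflow node executions, workflow runs, then workflow app
    # logs. Every category is serialized to storage before its rows are deleted.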

    @classmethod
    def process(cls, days: int, batch: int, tenant_ids: list[str]):
        """
        Clear free plan tenant expired logs.
        """
        click.echo(click.style("Clearing free plan tenant expired logs", fg="white"))
        ended_at = datetime.datetime.now()
        started_at = datetime.datetime(2023, 4, 3, 8, 59, 24)
        current_time = started_at

        with Session(db.engine) as session:
            total_tenant_count = session.query(Tenant.id).count()
            click.echo(click.style(f"Total tenant count: {total_tenant_count}", fg="white"))

        handled_tenant_count = 0
        thread_pool = ThreadPoolExecutor(max_workers=10)

        def process_tenant(flask_app: Flask, tenant_id: str) -> None:
            try:
                if (
                    not dify_config.BILLING_ENABLED
                    or BillingService.get_info(tenant_id)["subscription"]["plan"] == "sandbox"
                ):
                    # only process sandbox (free plan) tenants
                    cls.process_tenant(flask_app, tenant_id, days, batch)
            except Exception:
                logger.exception("Failed to process tenant %s", tenant_id)
            finally:
                nonlocal handled_tenant_count
                handled_tenant_count += 1
                if handled_tenant_count % 100 == 0:
                    click.echo(
                        click.style(
                            f"[{datetime.datetime.now()}] "
                            f"Processed {handled_tenant_count} tenants "
                            f"({(handled_tenant_count / total_tenant_count) * 100:.1f}%), "
                            f"{handled_tenant_count}/{total_tenant_count}",
                            fg="green",
                        )
                    )

        futures = []

        if tenant_ids:
            for tenant_id in tenant_ids:
                futures.append(
                    thread_pool.submit(
                        process_tenant,
                        current_app._get_current_object(),  # type: ignore[attr-defined]
                        tenant_id,
                    )
                )
        else:
            while current_time < ended_at:
                click.echo(
                    click.style(f"Batch start: {current_time}, current time: {datetime.datetime.now()}", fg="white")
                )
                # Initial interval of 1 day, will be dynamically adjusted based on tenant count
                interval = datetime.timedelta(days=1)

                # Process tenants in this batch
                with Session(db.engine) as session:
                    # Calculate the tenant count in the next batch with the current interval.
                    # Try different intervals until we find one with a reasonable tenant count.
                    test_intervals = [
                        datetime.timedelta(days=1),
                        datetime.timedelta(hours=12),
                        datetime.timedelta(hours=6),
                        datetime.timedelta(hours=3),
                        datetime.timedelta(hours=1),
                    ]
                    for test_interval in test_intervals:
                        tenant_count = (
                            session.query(Tenant.id)
                            .where(Tenant.created_at.between(current_time, current_time + test_interval))
                            .count()
                        )
                        if tenant_count <= 100:
                            interval = test_interval
                            break
                    else:
                        # If all intervals have too many tenants, use the minimum interval
                        interval = datetime.timedelta(hours=1)

                    # Adjust interval to target ~100 tenants per batch
                    if tenant_count > 0:
                        # Scale interval based on the ratio to the target count
                        interval = min(
                            datetime.timedelta(days=1),  # Max 1 day
                            max(
                                datetime.timedelta(hours=1),  # Min 1 hour
                                interval * (100 / tenant_count),  # Scale to target 100
                            ),
                        )
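
                    # Worked example (illustrative numbers): if the 6-hour probe
                    # finds tenant_count = 50, the scaled interval is
                    # 6 h * (100 / 50) = 12 h, which lies within the
                    # [1 hour, 1 day] clamp applied by max()/min() above.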

                    batch_end = min(current_time + interval, ended_at)

                    rs = (
                        session.query(Tenant.id)
                        .where(Tenant.created_at.between(current_time, batch_end))
                        .order_by(Tenant.created_at)
                    )

                    tenants = []
                    for row in rs:
                        tenant_id = str(row.id)
                        try:
                            tenants.append(tenant_id)
                        except Exception:
                            logger.exception("Failed to process tenant %s", tenant_id)
                            continue

                        futures.append(
                            thread_pool.submit(
                                process_tenant,
                                current_app._get_current_object(),  # type: ignore[attr-defined]
                                tenant_id,
                            )
                        )

                current_time = batch_end

        # wait for all threads to finish
        for future in futures:
            future.result()
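
For reference, a minimal sketch of how this class could be registered as a Flask CLI command. The command name, option names, defaults, and import path below are illustrative assumptions, not taken from this file:

import click
from flask.cli import with_appcontext

# NOTE: module path is an assumption based on this file's package layout.
from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpiredLogs


@click.command("clear-free-plan-tenant-expired-logs")
@click.option("--days", default=30, type=int, help="Purge logs older than this many days.")
@click.option("--batch", default=100, type=int, help="Rows to back up and delete per batch.")
@click.option("--tenant_ids", multiple=True, help="Explicit tenant IDs; if omitted, all tenants are scanned.")
@with_appcontext
def clear_free_plan_tenant_expired_logs(days: int, batch: int, tenant_ids: tuple[str, ...]) -> None:
    """Back up and purge expired logs for free (sandbox) plan tenants."""
    ClearFreePlanTenantExpiredLogs.process(days, batch, list(tenant_ids))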