
clear_free_plan_tenant_expired_logs.py 19KB
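"""Archive and purge expired logs for free (sandbox) plan tenants.

For each tenant, expired messages (and their related tables), conversations,
workflow node executions, workflow runs, and workflow app logs are serialized
to JSON under
``free_plan_tenant_expired_logs/{tenant_id}/{table}/{YYYY-MM-DD}-{timestamp}.json``
in object storage before the rows are deleted from the database in batches.
"""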

import datetime
import json
import logging
import time
from concurrent.futures import ThreadPoolExecutor

import click
from flask import Flask, current_app
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker

from configs import dify_config
from core.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.account import Tenant
from models.model import (
    App,
    AppAnnotationHitHistory,
    Conversation,
    Message,
    MessageAgentThought,
    MessageAnnotation,
    MessageChain,
    MessageFeedback,
    MessageFile,
)
from models.web import SavedMessage
from models.workflow import WorkflowAppLog
from repositories.factory import DifyAPIRepositoryFactory
from services.billing_service import BillingService

logger = logging.getLogger(__name__)


class ClearFreePlanTenantExpiredLogs:
    @classmethod
    def _clear_message_related_tables(cls, session: Session, tenant_id: str, batch_message_ids: list[str]):
        """
        Clean up message-related tables so no orphaned rows are left behind.

        This method cleans up tables that have foreign key relationships with Message.

        Args:
            session: Database session, the same one used in the process_tenant method
            tenant_id: Tenant ID, used for logging purposes
            batch_message_ids: List of message IDs to clean up
        """
        if not batch_message_ids:
            return

        # Clean up each related table
        related_tables = [
            (MessageFeedback, "message_feedbacks"),
            (MessageFile, "message_files"),
            (MessageAnnotation, "message_annotations"),
            (MessageChain, "message_chains"),
            (MessageAgentThought, "message_agent_thoughts"),
            (AppAnnotationHitHistory, "app_annotation_hit_histories"),
            (SavedMessage, "saved_messages"),
        ]
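
        # For each related table: load the rows tied to the expired messages,
        # back them up to object storage as JSON, then delete them. A failed
        # backup is logged but does not stop the deletion below.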
        for model, table_name in related_tables:
            # Query records related to expired messages
            records = (
                session.query(model)
                .where(
                    model.message_id.in_(batch_message_ids),  # type: ignore
                )
                .all()
            )

            if len(records) == 0:
                continue

            # Save records before deletion
            record_ids = [record.id for record in records]
            try:
                record_data = []
                for record in records:
                    try:
                        if hasattr(record, "to_dict"):
                            record_data.append(record.to_dict())
                        else:
                            # if the record has no to_dict method, build the dict from its columns manually
                            record_dict = {}
                            for column in record.__table__.columns:
                                record_dict[column.name] = getattr(record, column.name)
                            record_data.append(record_dict)
                    except Exception:
                        logger.exception("Failed to transform %s record: %s", table_name, record.id)
                        continue

                if record_data:
                    storage.save(
                        f"free_plan_tenant_expired_logs/"
                        f"{tenant_id}/{table_name}/{datetime.datetime.now().strftime('%Y-%m-%d')}"
                        f"-{time.time()}.json",
                        json.dumps(
                            jsonable_encoder(record_data),
                        ).encode("utf-8"),
                    )
            except Exception:
                logger.exception("Failed to save %s records", table_name)

            session.query(model).where(
                model.id.in_(record_ids),  # type: ignore
            ).delete(synchronize_session=False)

            click.echo(
                click.style(
                    f"[{datetime.datetime.now()}] Processed {len(record_ids)} "
                    f"{table_name} records for tenant {tenant_id}"
                )
            )
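
    # Per-tenant purge flow: each data family below (messages and their related
    # tables, conversations, workflow node executions, workflow runs, workflow
    # app logs) is fetched in batches of `batch` rows older than `days` days,
    # backed up to storage, and only then deleted.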
    @classmethod
    def process_tenant(cls, flask_app: Flask, tenant_id: str, days: int, batch: int):
        with flask_app.app_context():
            apps = db.session.scalars(select(App).where(App.tenant_id == tenant_id)).all()
            app_ids = [app.id for app in apps]
            while True:
                with Session(db.engine).no_autoflush as session:
                    messages = (
                        session.query(Message)
                        .where(
                            Message.app_id.in_(app_ids),
                            Message.created_at < datetime.datetime.now() - datetime.timedelta(days=days),
                        )
                        .limit(batch)
                        .all()
                    )

                    if len(messages) == 0:
                        break

                    storage.save(
                        f"free_plan_tenant_expired_logs/"
                        f"{tenant_id}/messages/{datetime.datetime.now().strftime('%Y-%m-%d')}"
                        f"-{time.time()}.json",
                        json.dumps(
                            jsonable_encoder(
                                [message.to_dict() for message in messages],
                            ),
                        ).encode("utf-8"),
                    )

                    message_ids = [message.id for message in messages]
                    # delete messages
                    session.query(Message).where(
                        Message.id.in_(message_ids),
                    ).delete(synchronize_session=False)

                    cls._clear_message_related_tables(session, tenant_id, message_ids)
                    session.commit()

                    click.echo(
                        click.style(
                            f"[{datetime.datetime.now()}] Processed {len(message_ids)} messages for tenant {tenant_id}"
                        )
                    )

            while True:
                with Session(db.engine).no_autoflush as session:
                    conversations = (
                        session.query(Conversation)
                        .where(
                            Conversation.app_id.in_(app_ids),
                            Conversation.updated_at < datetime.datetime.now() - datetime.timedelta(days=days),
                        )
                        .limit(batch)
                        .all()
                    )

                    if len(conversations) == 0:
                        break

                    storage.save(
                        f"free_plan_tenant_expired_logs/"
                        f"{tenant_id}/conversations/{datetime.datetime.now().strftime('%Y-%m-%d')}"
                        f"-{time.time()}.json",
                        json.dumps(
                            jsonable_encoder(
                                [conversation.to_dict() for conversation in conversations],
                            ),
                        ).encode("utf-8"),
                    )

                    conversation_ids = [conversation.id for conversation in conversations]
                    session.query(Conversation).where(
                        Conversation.id.in_(conversation_ids),
                    ).delete(synchronize_session=False)
                    session.commit()

                    click.echo(
                        click.style(
                            f"[{datetime.datetime.now()}] Processed {len(conversation_ids)}"
                            f" conversations for tenant {tenant_id}"
                        )
                    )

            # Process expired workflow node executions with backup
            session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
            node_execution_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(session_maker)
            before_date = datetime.datetime.now() - datetime.timedelta(days=days)
            total_deleted = 0
            while True:
                # Get a batch of expired executions for backup
                workflow_node_executions = node_execution_repo.get_expired_executions_batch(
                    tenant_id=tenant_id,
                    before_date=before_date,
                    batch_size=batch,
                )

                if len(workflow_node_executions) == 0:
                    break

                # Save workflow node executions to storage
                storage.save(
                    f"free_plan_tenant_expired_logs/"
                    f"{tenant_id}/workflow_node_executions/{datetime.datetime.now().strftime('%Y-%m-%d')}"
                    f"-{time.time()}.json",
                    json.dumps(
                        jsonable_encoder(workflow_node_executions),
                    ).encode("utf-8"),
                )

                # Extract IDs for deletion
                workflow_node_execution_ids = [
                    workflow_node_execution.id for workflow_node_execution in workflow_node_executions
                ]

                # Delete the backed-up executions
                deleted_count = node_execution_repo.delete_executions_by_ids(workflow_node_execution_ids)
                total_deleted += deleted_count

                click.echo(
                    click.style(
                        f"[{datetime.datetime.now()}] Processed {len(workflow_node_execution_ids)}"
                        f" workflow node executions for tenant {tenant_id}"
                    )
                )

                # If we got fewer than the batch size, we're done
                if len(workflow_node_executions) < batch:
                    break

            # Process expired workflow runs with backup
            session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
            workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
            before_date = datetime.datetime.now() - datetime.timedelta(days=days)
            total_deleted = 0
            while True:
                # Get a batch of expired workflow runs for backup
                workflow_runs = workflow_run_repo.get_expired_runs_batch(
                    tenant_id=tenant_id,
                    before_date=before_date,
                    batch_size=batch,
                )

                if len(workflow_runs) == 0:
                    break

                # Save workflow runs to storage
                storage.save(
                    f"free_plan_tenant_expired_logs/"
                    f"{tenant_id}/workflow_runs/{datetime.datetime.now().strftime('%Y-%m-%d')}"
                    f"-{time.time()}.json",
                    json.dumps(
                        jsonable_encoder(
                            [workflow_run.to_dict() for workflow_run in workflow_runs],
                        ),
                    ).encode("utf-8"),
                )

                # Extract IDs for deletion
                workflow_run_ids = [workflow_run.id for workflow_run in workflow_runs]

                # Delete the backed-up workflow runs
                deleted_count = workflow_run_repo.delete_runs_by_ids(workflow_run_ids)
                total_deleted += deleted_count

                click.echo(
                    click.style(
                        f"[{datetime.datetime.now()}] Processed {len(workflow_run_ids)}"
                        f" workflow runs for tenant {tenant_id}"
                    )
                )

                # If we got fewer than the batch size, we're done
                if len(workflow_runs) < batch:
                    break

            while True:
                with Session(db.engine).no_autoflush as session:
                    workflow_app_logs = (
                        session.query(WorkflowAppLog)
                        .where(
                            WorkflowAppLog.tenant_id == tenant_id,
                            WorkflowAppLog.created_at < datetime.datetime.now() - datetime.timedelta(days=days),
                        )
                        .limit(batch)
                        .all()
                    )

                    if len(workflow_app_logs) == 0:
                        break

                    # save workflow app logs
                    storage.save(
                        f"free_plan_tenant_expired_logs/"
                        f"{tenant_id}/workflow_app_logs/{datetime.datetime.now().strftime('%Y-%m-%d')}"
                        f"-{time.time()}.json",
                        json.dumps(
                            jsonable_encoder(
                                [workflow_app_log.to_dict() for workflow_app_log in workflow_app_logs],
                            ),
                        ).encode("utf-8"),
                    )

                    workflow_app_log_ids = [workflow_app_log.id for workflow_app_log in workflow_app_logs]

                    # delete workflow app logs
                    session.query(WorkflowAppLog).where(WorkflowAppLog.id.in_(workflow_app_log_ids)).delete(
                        synchronize_session=False
                    )
                    session.commit()

                    click.echo(
                        click.style(
                            f"[{datetime.datetime.now()}] Processed {len(workflow_app_log_ids)}"
                            f" workflow app logs for tenant {tenant_id}"
                        )
                    )

    @classmethod
    def process(cls, days: int, batch: int, tenant_ids: list[str]):
        """
        Clear free plan tenant expired logs.
        """
        click.echo(click.style("Clearing free plan tenant expired logs", fg="white"))
        ended_at = datetime.datetime.now()
        started_at = datetime.datetime(2023, 4, 3, 8, 59, 24)
        current_time = started_at

        with Session(db.engine) as session:
            total_tenant_count = session.query(Tenant.id).count()
            click.echo(click.style(f"Total tenant count: {total_tenant_count}", fg="white"))

        handled_tenant_count = 0

        thread_pool = ThreadPoolExecutor(max_workers=10)

        def process_tenant(flask_app: Flask, tenant_id: str):
            try:
                if (
                    not dify_config.BILLING_ENABLED
                    or BillingService.get_info(tenant_id)["subscription"]["plan"] == "sandbox"
                ):
                    # only process sandbox (free plan) tenants
                    cls.process_tenant(flask_app, tenant_id, days, batch)
            except Exception:
                logger.exception("Failed to process tenant %s", tenant_id)
            finally:
                nonlocal handled_tenant_count
                handled_tenant_count += 1
                if handled_tenant_count % 100 == 0:
                    click.echo(
                        click.style(
                            f"[{datetime.datetime.now()}] "
                            f"Processed {handled_tenant_count} tenants "
                            f"({(handled_tenant_count / total_tenant_count) * 100:.1f}%), "
                            f"{handled_tenant_count}/{total_tenant_count}",
                            fg="green",
                        )
                    )

        futures = []

        if tenant_ids:
            for tenant_id in tenant_ids:
                futures.append(
                    thread_pool.submit(
                        process_tenant,
                        current_app._get_current_object(),  # type: ignore[attr-defined]
                        tenant_id,
                    )
                )
        else:
            while current_time < ended_at:
                click.echo(
                    click.style(f"Current window start: {current_time}, now: {datetime.datetime.now()}", fg="white")
                )
                # Initial interval of 1 day, dynamically adjusted based on tenant count
                interval = datetime.timedelta(days=1)

                # Process tenants in this batch
                with Session(db.engine) as session:
                    # Calculate the tenant count in the next batch with the current interval.
                    # Try different intervals until we find one with a reasonable tenant count.
                    test_intervals = [
                        datetime.timedelta(days=1),
                        datetime.timedelta(hours=12),
                        datetime.timedelta(hours=6),
                        datetime.timedelta(hours=3),
                        datetime.timedelta(hours=1),
                    ]

                    tenant_count = 0
                    for test_interval in test_intervals:
                        tenant_count = (
                            session.query(Tenant.id)
                            .where(Tenant.created_at.between(current_time, current_time + test_interval))
                            .count()
                        )
                        if tenant_count <= 100:
                            interval = test_interval
                            break
                    else:
                        # If all intervals have too many tenants, use the minimum interval
                        interval = datetime.timedelta(hours=1)

                    # Adjust interval to target ~100 tenants per batch
                    if tenant_count > 0:
                        # Scale interval based on the ratio to the target count
                        interval = min(
                            datetime.timedelta(days=1),  # Max 1 day
                            max(
                                datetime.timedelta(hours=1),  # Min 1 hour
                                interval * (100 / tenant_count),  # Scale to target 100
                            ),
                        )
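
                    # Worked example: if the 1-day probe holds too many tenants but
                    # the 12-hour probe finds 80, the window is rescaled to
                    # 12h * (100 / 80) = 15h, then clamped to the [1 hour, 1 day] range.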

                    batch_end = min(current_time + interval, ended_at)

                    rs = (
                        session.query(Tenant.id)
                        .where(Tenant.created_at.between(current_time, batch_end))
                        .order_by(Tenant.created_at)
                    )

                    tenants = []
                    for row in rs:
                        tenant_id = str(row.id)
                        try:
                            tenants.append(tenant_id)
                        except Exception:
                            logger.exception("Failed to process tenant %s", tenant_id)
                            continue

                        futures.append(
                            thread_pool.submit(
                                process_tenant,
                                current_app._get_current_object(),  # type: ignore[attr-defined]
                                tenant_id,
                            )
                        )

                current_time = batch_end

        # wait for all threads to finish
        for future in futures:
            future.result()
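

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): in Dify,
# maintenance jobs like this are invoked through Flask CLI commands. A minimal
# wiring might look like the snippet below; the command name and option
# defaults here are assumptions, not the project's actual CLI definition.
#
#   @click.command("clear-free-plan-tenant-expired-logs")
#   @click.option("--days", default=30, help="Purge logs older than this many days.")
#   @click.option("--batch", default=100, help="Rows fetched and deleted per batch.")
#   @click.option("--tenant_ids", multiple=True, help="Restrict to specific tenant IDs.")
#   def clear_free_plan_tenant_expired_logs(days: int, batch: int, tenant_ids: tuple[str, ...]):
#       ClearFreePlanTenantExpiredLogs.process(days, batch, list(tenant_ids))
#
# Registered on the Flask app (e.g. app.cli.add_command(...)), it would run as:
#   flask clear-free-plan-tenant-expired-logs --days 30 --batch 100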