You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

clear_free_plan_tenant_expired_logs.py 13KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329
  1. import datetime
  2. import json
  3. import logging
  4. import time
  5. from concurrent.futures import ThreadPoolExecutor
  6. import click
  7. from flask import Flask, current_app
  8. from sqlalchemy.orm import Session, sessionmaker
  9. from configs import dify_config
  10. from core.model_runtime.utils.encoders import jsonable_encoder
  11. from extensions.ext_database import db
  12. from extensions.ext_storage import storage
  13. from models.account import Tenant
  14. from models.model import App, Conversation, Message
  15. from repositories.factory import DifyAPIRepositoryFactory
  16. from services.billing_service import BillingService
# Module-level logger; named after the module so records are attributable.
logger = logging.getLogger(__name__)
class ClearFreePlanTenantExpiredLogs:
    """Back up and purge expired logs for free-plan ("sandbox") tenants.

    For each tenant, expired records (messages, conversations, workflow node
    executions, and workflow runs) are first serialized as JSON to storage
    under ``free_plan_tenant_expired_logs/<tenant_id>/...`` and only then
    deleted, batch by batch, so a backup always exists before deletion.
    """

    @classmethod
    def process_tenant(cls, flask_app: Flask, tenant_id: str, days: int, batch: int):
        """Back up and delete one tenant's expired log records.

        :param flask_app: Flask application; an app context is pushed so the
            ``db``/``storage`` extensions work from a worker thread.
        :param tenant_id: tenant whose logs are processed.
        :param days: records older than this many days count as expired.
        :param batch: maximum number of rows fetched/deleted per iteration.
        """
        with flask_app.app_context():
            # All apps of the tenant; expired messages/conversations are
            # looked up by app id.
            apps = db.session.query(App).filter(App.tenant_id == tenant_id).all()
            app_ids = [app.id for app in apps]

            # --- messages ---------------------------------------------------
            # Repeat until no expired messages remain: back up one batch to
            # storage, then delete exactly that batch.
            while True:
                # NOTE(review): Session(db.engine) is never explicitly closed
                # here — `.no_autoflush` only toggles autoflush, it does not
                # close the session on exit; the connection is released when
                # the session object is garbage-collected. Confirm intended.
                with Session(db.engine).no_autoflush as session:
                    messages = (
                        session.query(Message)
                        .filter(
                            Message.app_id.in_(app_ids),
                            # Naive local time; assumes created_at is stored
                            # as a naive timestamp as well — TODO confirm.
                            Message.created_at < datetime.datetime.now() - datetime.timedelta(days=days),
                        )
                        .limit(batch)
                        .all()
                    )

                    if len(messages) == 0:
                        break

                    # Backup first: persist the batch as UTF-8 JSON, keyed by
                    # date plus epoch seconds so repeated runs never collide.
                    storage.save(
                        f"free_plan_tenant_expired_logs/"
                        f"{tenant_id}/messages/{datetime.datetime.now().strftime('%Y-%m-%d')}"
                        f"-{time.time()}.json",
                        json.dumps(
                            jsonable_encoder(
                                [message.to_dict() for message in messages],
                            ),
                        ).encode("utf-8"),
                    )

                    message_ids = [message.id for message in messages]
                    # delete messages
                    session.query(Message).filter(
                        Message.id.in_(message_ids),
                    ).delete(synchronize_session=False)
                    session.commit()

                    click.echo(
                        click.style(
                            f"[{datetime.datetime.now()}] Processed {len(message_ids)} messages for tenant {tenant_id} "
                        )
                    )

            # --- conversations ----------------------------------------------
            # Same backup-then-delete loop, keyed on updated_at rather than
            # created_at.
            while True:
                with Session(db.engine).no_autoflush as session:
                    conversations = (
                        session.query(Conversation)
                        .filter(
                            Conversation.app_id.in_(app_ids),
                            Conversation.updated_at < datetime.datetime.now() - datetime.timedelta(days=days),
                        )
                        .limit(batch)
                        .all()
                    )

                    if len(conversations) == 0:
                        break

                    storage.save(
                        f"free_plan_tenant_expired_logs/"
                        f"{tenant_id}/conversations/{datetime.datetime.now().strftime('%Y-%m-%d')}"
                        f"-{time.time()}.json",
                        json.dumps(
                            jsonable_encoder(
                                [conversation.to_dict() for conversation in conversations],
                            ),
                        ).encode("utf-8"),
                    )

                    conversation_ids = [conversation.id for conversation in conversations]
                    session.query(Conversation).filter(
                        Conversation.id.in_(conversation_ids),
                    ).delete(synchronize_session=False)
                    session.commit()

                    click.echo(
                        click.style(
                            f"[{datetime.datetime.now()}] Processed {len(conversation_ids)}"
                            f" conversations for tenant {tenant_id}"
                        )
                    )

            # Process expired workflow node executions with backup
            session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
            node_execution_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(session_maker)

            before_date = datetime.datetime.now() - datetime.timedelta(days=days)
            # NOTE(review): total_deleted is accumulated but never reported or
            # returned — presumably intended for a summary log. TODO confirm.
            total_deleted = 0

            while True:
                # Get a batch of expired executions for backup
                workflow_node_executions = node_execution_repo.get_expired_executions_batch(
                    tenant_id=tenant_id,
                    before_date=before_date,
                    batch_size=batch,
                )

                if len(workflow_node_executions) == 0:
                    break

                # Save workflow node executions to storage
                storage.save(
                    f"free_plan_tenant_expired_logs/"
                    f"{tenant_id}/workflow_node_executions/{datetime.datetime.now().strftime('%Y-%m-%d')}"
                    f"-{time.time()}.json",
                    json.dumps(
                        jsonable_encoder(workflow_node_executions),
                    ).encode("utf-8"),
                )

                # Extract IDs for deletion
                workflow_node_execution_ids = [
                    workflow_node_execution.id for workflow_node_execution in workflow_node_executions
                ]

                # Delete the backed up executions
                deleted_count = node_execution_repo.delete_executions_by_ids(workflow_node_execution_ids)
                total_deleted += deleted_count

                click.echo(
                    click.style(
                        f"[{datetime.datetime.now()}] Processed {len(workflow_node_execution_ids)}"
                        f" workflow node executions for tenant {tenant_id}"
                    )
                )

                # If we got fewer than the batch size, we're done
                if len(workflow_node_executions) < batch:
                    break

            # Process expired workflow runs with backup
            session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
            workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)

            before_date = datetime.datetime.now() - datetime.timedelta(days=days)
            total_deleted = 0

            while True:
                # Get a batch of expired workflow runs for backup
                workflow_runs = workflow_run_repo.get_expired_runs_batch(
                    tenant_id=tenant_id,
                    before_date=before_date,
                    batch_size=batch,
                )

                if len(workflow_runs) == 0:
                    break

                # Save workflow runs to storage
                storage.save(
                    f"free_plan_tenant_expired_logs/"
                    f"{tenant_id}/workflow_runs/{datetime.datetime.now().strftime('%Y-%m-%d')}"
                    f"-{time.time()}.json",
                    json.dumps(
                        jsonable_encoder(
                            [workflow_run.to_dict() for workflow_run in workflow_runs],
                        ),
                    ).encode("utf-8"),
                )

                # Extract IDs for deletion
                workflow_run_ids = [workflow_run.id for workflow_run in workflow_runs]

                # Delete the backed up workflow runs
                deleted_count = workflow_run_repo.delete_runs_by_ids(workflow_run_ids)
                total_deleted += deleted_count

                click.echo(
                    click.style(
                        f"[{datetime.datetime.now()}] Processed {len(workflow_run_ids)}"
                        f" workflow runs for tenant {tenant_id}"
                    )
                )

                # If we got fewer than the batch size, we're done
                if len(workflow_runs) < batch:
                    break

    @classmethod
    def process(cls, days: int, batch: int, tenant_ids: list[str]):
        """
        Clear free plan tenant expired logs.

        When ``tenant_ids`` is non-empty, only those tenants are processed;
        otherwise all tenants are discovered by scanning ``Tenant.created_at``
        in adaptively sized time windows (targeting ~100 tenants per window).
        Work is fanned out to a 10-thread pool; only tenants on the "sandbox"
        billing plan are processed when billing is enabled.

        :param days: retention threshold passed through to process_tenant.
        :param batch: per-iteration batch size passed through to process_tenant.
        :param tenant_ids: explicit tenant ids, or empty to scan all tenants.
        """
        click.echo(click.style("Clearing free plan tenant expired logs", fg="white"))
        ended_at = datetime.datetime.now()
        # Scan window start — presumably around the earliest tenant creation
        # date worth scanning; TODO confirm the significance of this constant.
        started_at = datetime.datetime(2023, 4, 3, 8, 59, 24)
        current_time = started_at
        with Session(db.engine) as session:
            total_tenant_count = session.query(Tenant.id).count()
            click.echo(click.style(f"Total tenant count: {total_tenant_count}", fg="white"))

        handled_tenant_count = 0

        thread_pool = ThreadPoolExecutor(max_workers=10)

        def process_tenant(flask_app: Flask, tenant_id: str) -> None:
            # Per-tenant worker run on the thread pool: filters to sandbox
            # tenants, delegates to cls.process_tenant, and always bumps the
            # shared progress counter (even on failure).
            try:
                if (
                    not dify_config.BILLING_ENABLED
                    or BillingService.get_info(tenant_id)["subscription"]["plan"] == "sandbox"
                ):
                    # only process sandbox tenant
                    cls.process_tenant(flask_app, tenant_id, days, batch)
            except Exception:
                logger.exception(f"Failed to process tenant {tenant_id}")
            finally:
                nonlocal handled_tenant_count
                handled_tenant_count += 1
                # Progress line every 100 tenants.
                # NOTE(review): += on the nonlocal counter from 10 threads is
                # not synchronized — counts may drift slightly; confirm this
                # is acceptable for a progress log.
                if handled_tenant_count % 100 == 0:
                    click.echo(
                        click.style(
                            f"[{datetime.datetime.now()}] "
                            f"Processed {handled_tenant_count} tenants "
                            f"({(handled_tenant_count / total_tenant_count) * 100:.1f}%), "
                            f"{handled_tenant_count}/{total_tenant_count}",
                            fg="green",
                        )
                    )

        futures = []

        if tenant_ids:
            # Explicit tenant list: submit each one directly.
            for tenant_id in tenant_ids:
                futures.append(
                    thread_pool.submit(
                        process_tenant,
                        current_app._get_current_object(),  # type: ignore[attr-defined]
                        tenant_id,
                    )
                )
        else:
            # Full scan: walk [started_at, ended_at) in time windows over
            # Tenant.created_at, shrinking each window so it holds a
            # manageable number of tenants.
            while current_time < ended_at:
                # NOTE(review): the label says "Started at" but the value is
                # datetime.datetime.now() — likely meant to print started_at
                # or be labeled differently; confirm intent.
                click.echo(
                    click.style(f"Current time: {current_time}, Started at: {datetime.datetime.now()}", fg="white")
                )

                # Initial interval of 1 day, will be dynamically adjusted based on tenant count
                interval = datetime.timedelta(days=1)
                # Process tenants in this batch
                with Session(db.engine) as session:
                    # Calculate tenant count in next batch with current interval
                    # Try different intervals until we find one with a reasonable tenant count
                    test_intervals = [
                        datetime.timedelta(days=1),
                        datetime.timedelta(hours=12),
                        datetime.timedelta(hours=6),
                        datetime.timedelta(hours=3),
                        datetime.timedelta(hours=1),
                    ]

                    for test_interval in test_intervals:
                        tenant_count = (
                            session.query(Tenant.id)
                            .filter(Tenant.created_at.between(current_time, current_time + test_interval))
                            .count()
                        )
                        if tenant_count <= 100:
                            interval = test_interval
                            break
                    else:
                        # If all intervals have too many tenants, use minimum interval
                        interval = datetime.timedelta(hours=1)

                    # Adjust interval to target ~100 tenants per batch
                    if tenant_count > 0:
                        # Scale interval based on ratio to target count
                        interval = min(
                            datetime.timedelta(days=1),  # Max 1 day
                            max(
                                datetime.timedelta(hours=1),  # Min 1 hour
                                interval * (100 / tenant_count),  # Scale to target 100
                            ),
                        )

                    batch_end = min(current_time + interval, ended_at)

                    rs = (
                        session.query(Tenant.id)
                        .filter(Tenant.created_at.between(current_time, batch_end))
                        .order_by(Tenant.created_at)
                    )

                    # NOTE(review): `tenants` is appended to but never read,
                    # and the try/except around a plain list append cannot
                    # realistically fire — this looks vestigial; confirm.
                    tenants = []
                    for row in rs:
                        tenant_id = str(row.id)
                        try:
                            tenants.append(tenant_id)
                        except Exception:
                            logger.exception(f"Failed to process tenant {tenant_id}")
                            continue

                        futures.append(
                            thread_pool.submit(
                                process_tenant,
                                current_app._get_current_object(),  # type: ignore[attr-defined]
                                tenant_id,
                            )
                        )

                current_time = batch_end

        # wait for all threads to finish
        for future in futures:
            future.result()