
try ast-grep (#24149)

tags/1.8.0
Asuka Minato 2 months ago
commit 70da81d0e5

.github/workflows/autofix.yml (+3 −0)

          uv run ruff check --fix-only .
          # Format code
          uv run ruff format .
+     - name: ast-grep
+       run: |
+         uvx --from ast-grep-cli sg --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all

      - uses: autofix-ci/action@635ffb0c9798bd160680f18fd73371e355b85f27
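The $WHATEVER and $HERE metavariables let the pattern match any arguments inside query() and filter(), so every db.session.query(...).filter(...) call in the Python sources is rewritten to .where(...). Query.where() is documented as a synonym for Query.filter() (added in SQLAlchemy 1.4), so the rewrite is expected to be behavior-preserving. A minimal, self-contained sketch of the before/after that the pattern produces (the Task model, in-memory engine, and plain Session below are illustrative stand-ins, not this repository's db.session setup):

    from sqlalchemy import Column, Integer, String, create_engine
    from sqlalchemy.orm import Session, declarative_base

    Base = declarative_base()

    class Task(Base):  # hypothetical model, for illustration only
        __tablename__ = "tasks"
        id = Column(Integer, primary_key=True)
        name = Column(String)

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        # Before the rewrite (what the ast-grep pattern matches):
        task = session.query(Task).filter(Task.id == 1).first()
        # After the rewrite (what --rewrite produces):
        task = session.query(Task).where(Task.id == 1).first()
        # Both calls behave identically; .where() is the 2.0-style synonym for .filter().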



api/controllers/console/app/generator.py (+1 −1)

  from models import App, db
  from services.workflow_service import WorkflowService

- app = db.session.query(App).filter(App.id == args["flow_id"]).first()
+ app = db.session.query(App).where(App.id == args["flow_id"]).first()
  if not app:
      return {"error": f"app {args['flow_id']} not found"}, 400
  workflow = WorkflowService().get_draft_workflow(app_model=app)

api/controllers/console/datasets/upload_file.py (+1 −1)

  data_source_info = document.data_source_info_dict
  if data_source_info and "upload_file_id" in data_source_info:
      file_id = data_source_info["upload_file_id"]
-     upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first()
+     upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
      if not upload_file:
          raise NotFound("UploadFile not found.")
  else:

api/core/app/task_pipeline/message_cycle_manager.py (+1 −1)

  :param message_id: message id
  :return:
  """
- message_file = db.session.query(MessageFile).filter(MessageFile.id == message_id).first()
+ message_file = db.session.query(MessageFile).where(MessageFile.id == message_id).first()
  event_type = StreamEvent.MESSAGE_FILE if message_file else StreamEvent.MESSAGE

  return MessageStreamResponse(

api/core/llm_generator/llm_generator.py (+3 −3)

  def instruction_modify_legacy(
      tenant_id: str, flow_id: str, current: str, instruction: str, model_config: dict, ideal_output: str | None
  ) -> dict:
-     app: App | None = db.session.query(App).filter(App.id == flow_id).first()
+     app: App | None = db.session.query(App).where(App.id == flow_id).first()
      last_run: Message | None = (
-         db.session.query(Message).filter(Message.app_id == flow_id).order_by(Message.created_at.desc()).first()
+         db.session.query(Message).where(Message.app_id == flow_id).order_by(Message.created_at.desc()).first()
      )
      if not last_run:
          return LLMGenerator.__instruction_modify_common(

  ) -> dict:
      from services.workflow_service import WorkflowService

-     app: App | None = db.session.query(App).filter(App.id == flow_id).first()
+     app: App | None = db.session.query(App).where(App.id == flow_id).first()
      if not app:
          raise ValueError("App not found.")
      workflow = WorkflowService().get_draft_workflow(app_model=app)

api/schedule/clean_workflow_runlogs_precise.py (+16 −16)

  cutoff_date = datetime.datetime.now() - datetime.timedelta(days=retention_days)

  try:
-     total_workflow_runs = db.session.query(WorkflowRun).filter(WorkflowRun.created_at < cutoff_date).count()
+     total_workflow_runs = db.session.query(WorkflowRun).where(WorkflowRun.created_at < cutoff_date).count()
      if total_workflow_runs == 0:
          _logger.info("No expired workflow run logs found")
          return

      while True:
          workflow_runs = (
-             db.session.query(WorkflowRun.id).filter(WorkflowRun.created_at < cutoff_date).limit(BATCH_SIZE).all()
+             db.session.query(WorkflowRun.id).where(WorkflowRun.created_at < cutoff_date).limit(BATCH_SIZE).all()
          )

          if not workflow_runs:

  message_id_list = [msg.id for msg in message_data]
  conversation_id_list = list({msg.conversation_id for msg in message_data if msg.conversation_id})
  if message_id_list:
-     db.session.query(AppAnnotationHitHistory).filter(
+     db.session.query(AppAnnotationHitHistory).where(
          AppAnnotationHitHistory.message_id.in_(message_id_list)
      ).delete(synchronize_session=False)

-     db.session.query(MessageAgentThought).filter(
-         MessageAgentThought.message_id.in_(message_id_list)
-     ).delete(synchronize_session=False)
+     db.session.query(MessageAgentThought).where(MessageAgentThought.message_id.in_(message_id_list)).delete(
+         synchronize_session=False
+     )

-     db.session.query(MessageChain).filter(MessageChain.message_id.in_(message_id_list)).delete(
+     db.session.query(MessageChain).where(MessageChain.message_id.in_(message_id_list)).delete(
          synchronize_session=False
      )

-     db.session.query(MessageFile).filter(MessageFile.message_id.in_(message_id_list)).delete(
+     db.session.query(MessageFile).where(MessageFile.message_id.in_(message_id_list)).delete(
          synchronize_session=False
      )

-     db.session.query(MessageAnnotation).filter(MessageAnnotation.message_id.in_(message_id_list)).delete(
+     db.session.query(MessageAnnotation).where(MessageAnnotation.message_id.in_(message_id_list)).delete(
          synchronize_session=False
      )

-     db.session.query(MessageFeedback).filter(MessageFeedback.message_id.in_(message_id_list)).delete(
+     db.session.query(MessageFeedback).where(MessageFeedback.message_id.in_(message_id_list)).delete(
          synchronize_session=False
      )

- db.session.query(Message).filter(Message.workflow_run_id.in_(workflow_run_ids)).delete(
+ db.session.query(Message).where(Message.workflow_run_id.in_(workflow_run_ids)).delete(
      synchronize_session=False
  )

- db.session.query(WorkflowAppLog).filter(WorkflowAppLog.workflow_run_id.in_(workflow_run_ids)).delete(
+ db.session.query(WorkflowAppLog).where(WorkflowAppLog.workflow_run_id.in_(workflow_run_ids)).delete(
      synchronize_session=False
  )

- db.session.query(WorkflowNodeExecutionModel).filter(
+ db.session.query(WorkflowNodeExecutionModel).where(
      WorkflowNodeExecutionModel.workflow_run_id.in_(workflow_run_ids)
  ).delete(synchronize_session=False)

  if conversation_id_list:
-     db.session.query(ConversationVariable).filter(
+     db.session.query(ConversationVariable).where(
          ConversationVariable.conversation_id.in_(conversation_id_list)
      ).delete(synchronize_session=False)

-     db.session.query(Conversation).filter(Conversation.id.in_(conversation_id_list)).delete(
+     db.session.query(Conversation).where(Conversation.id.in_(conversation_id_list)).delete(
          synchronize_session=False
      )

- db.session.query(WorkflowRun).filter(WorkflowRun.id.in_(workflow_run_ids)).delete(synchronize_session=False)
+ db.session.query(WorkflowRun).where(WorkflowRun.id.in_(workflow_run_ids)).delete(synchronize_session=False)

  db.session.commit()
  return True
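All of the bulk deletes in this file pass synchronize_session=False, which makes SQLAlchemy emit the DELETE directly against the table without reconciling objects already loaded in the session; that is acceptable here because each batch is committed right afterwards and the deleted rows are not reused. A rough, self-contained sketch of that pattern under the same assumptions as the earlier example (hypothetical WorkflowRun model and in-memory engine, not this repository's setup):

    from datetime import datetime, timedelta

    from sqlalchemy import Column, DateTime, Integer, create_engine
    from sqlalchemy.orm import Session, declarative_base

    Base = declarative_base()

    class WorkflowRun(Base):  # hypothetical stand-in for the real model
        __tablename__ = "workflow_runs"
        id = Column(Integer, primary_key=True)
        created_at = Column(DateTime, default=datetime.now)

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        cutoff = datetime.now() - timedelta(days=30)
        # Bulk DELETE issued straight against the table; in-session objects are
        # not synchronized, which is fine when committing immediately afterwards.
        session.query(WorkflowRun).where(WorkflowRun.created_at < cutoff).delete(
            synchronize_session=False
        )
        session.commit()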

api/services/annotation_service.py (+4 −4)

  annotation_ids_to_delete = [annotation.id for annotation, _ in annotations_to_delete]

  # Step 2: Bulk delete hit histories in a single query
- db.session.query(AppAnnotationHitHistory).filter(
+ db.session.query(AppAnnotationHitHistory).where(
      AppAnnotationHitHistory.annotation_id.in_(annotation_ids_to_delete)
  ).delete(synchronize_session=False)

  # Step 4: Bulk delete annotations in a single query
  deleted_count = (
      db.session.query(MessageAnnotation)
-     .filter(MessageAnnotation.id.in_(annotation_ids_to_delete))
+     .where(MessageAnnotation.id.in_(annotation_ids_to_delete))
      .delete(synchronize_session=False)
  )

      db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
  )

- annotations_query = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app_id)
+ annotations_query = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_id)
  for annotation in annotations_query.yield_per(100):
-     annotation_hit_histories_query = db.session.query(AppAnnotationHitHistory).filter(
+     annotation_hit_histories_query = db.session.query(AppAnnotationHitHistory).where(
          AppAnnotationHitHistory.annotation_id == annotation.id
      )
      for annotation_hit_history in annotation_hit_histories_query.yield_per(100):
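Unchanged context worth noting: the per-annotation loop above streams rows with Query.yield_per(100), fetching results in batches of 100 instead of materializing the whole result set. A compact sketch of that streaming pattern, again with a hypothetical model rather than this repository's:

    from sqlalchemy import Column, Integer, String, create_engine
    from sqlalchemy.orm import Session, declarative_base

    Base = declarative_base()

    class MessageAnnotation(Base):  # hypothetical stand-in
        __tablename__ = "message_annotations"
        id = Column(Integer, primary_key=True)
        app_id = Column(String)

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        query = session.query(MessageAnnotation).where(MessageAnnotation.app_id == "app-1")
        # yield_per(100) keeps memory bounded by streaming results in chunks of 100.
        for annotation in query.yield_per(100):
            print(annotation.id)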

api/tests/test_containers_integration_tests/services/test_annotation_service.py (+2 −2)

  # Verify annotation was deleted
  from extensions.ext_database import db

- deleted_annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).first()
+ deleted_annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first()
  assert deleted_annotation is None

  # Verify delete_annotation_index_task was called (when annotation setting exists)

  AppAnnotationService.delete_app_annotation(app.id, annotation_id)

  # Verify annotation was deleted
- deleted_annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).first()
+ deleted_annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first()
  assert deleted_annotation is None

  # Verify delete_annotation_index_task was called

api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py (+1 −1)

  # Verify extension was deleted
  from extensions.ext_database import db

- deleted_extension = db.session.query(APIBasedExtension).filter(APIBasedExtension.id == extension_id).first()
+ deleted_extension = db.session.query(APIBasedExtension).where(APIBasedExtension.id == extension_id).first()
  assert deleted_extension is None

  def test_save_extension_duplicate_name(self, db_session_with_containers, mock_external_service_dependencies):

api/tests/test_containers_integration_tests/services/test_message_service.py (+1 −1)

  # Verify feedback was deleted
  from extensions.ext_database import db

- deleted_feedback = db.session.query(MessageFeedback).filter(MessageFeedback.id == feedback.id).first()
+ deleted_feedback = db.session.query(MessageFeedback).where(MessageFeedback.id == feedback.id).first()
  assert deleted_feedback is None

  def test_create_feedback_no_rating_when_not_exists(

api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py (+1 −1)



  # Verify inherit config was created in database
  inherit_configs = (
-     db.session.query(LoadBalancingModelConfig).filter(LoadBalancingModelConfig.name == "__inherit__").all()
+     db.session.query(LoadBalancingModelConfig).where(LoadBalancingModelConfig.name == "__inherit__").all()
  )
  assert len(inherit_configs) == 1
