Browse Source

Use typing.Literal to replace str places (#24099)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
tags/1.8.0
Zhehao Peng 2 months ago
parent
commit
c0702aacac
No account linked to committer's email address

+ 3
- 3
api/controllers/console/app/annotation.py View File

from typing import Literal

from flask import request from flask import request
from flask_login import current_user from flask_login import current_user
from flask_restful import Resource, marshal, marshal_with, reqparse from flask_restful import Resource, marshal, marshal_with, reqparse
@login_required @login_required
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("annotation") @cloud_edition_billing_resource_check("annotation")
def post(self, app_id, action):
def post(self, app_id, action: Literal["enable", "disable"]):
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()


result = AppAnnotationService.enable_app_annotation(args, app_id) result = AppAnnotationService.enable_app_annotation(args, app_id)
elif action == "disable": elif action == "disable":
result = AppAnnotationService.disable_app_annotation(app_id) result = AppAnnotationService.disable_app_annotation(app_id)
else:
raise ValueError("Unsupported annotation reply action")
return result, 200 return result, 200





+ 3
- 5
api/controllers/console/datasets/datasets_document.py View File

import logging import logging
from argparse import ArgumentTypeError from argparse import ArgumentTypeError
from typing import cast
from typing import Literal, cast


from flask import request from flask import request
from flask_login import current_user from flask_login import current_user
@login_required @login_required
@account_initialization_required @account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, action):
def patch(self, dataset_id, document_id, action: Literal["pause", "resume"]):
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
document_id = str(document_id) document_id = str(document_id)
document = self.get_document(dataset_id, document_id) document = self.get_document(dataset_id, document_id)
document.paused_at = None document.paused_at = None
document.is_paused = False document.is_paused = False
db.session.commit() db.session.commit()
else:
raise InvalidActionError()


return {"result": "success"}, 200 return {"result": "success"}, 200


@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, action):
def patch(self, dataset_id, action: Literal["enable", "disable", "archive", "un_archive"]):
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
if dataset is None: if dataset is None:

+ 3
- 1
api/controllers/console/datasets/metadata.py View File

from typing import Literal

from flask_login import current_user from flask_login import current_user
from flask_restful import Resource, marshal_with, reqparse from flask_restful import Resource, marshal_with, reqparse
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
@login_required @login_required
@account_initialization_required @account_initialization_required
@enterprise_license_required @enterprise_license_required
def post(self, dataset_id, action):
def post(self, dataset_id, action: Literal["enable", "disable"]):
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str) dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None: if dataset is None:

+ 3
- 3
api/controllers/service_api/app/annotation.py View File

from typing import Literal

from flask import request from flask import request
from flask_restful import Resource, marshal, marshal_with, reqparse from flask_restful import Resource, marshal, marshal_with, reqparse
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden


class AnnotationReplyActionApi(Resource): class AnnotationReplyActionApi(Resource):
@validate_app_token @validate_app_token
def post(self, app_model: App, action):
def post(self, app_model: App, action: Literal["enable", "disable"]):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("score_threshold", required=True, type=float, location="json") parser.add_argument("score_threshold", required=True, type=float, location="json")
parser.add_argument("embedding_provider_name", required=True, type=str, location="json") parser.add_argument("embedding_provider_name", required=True, type=str, location="json")
result = AppAnnotationService.enable_app_annotation(args, app_model.id) result = AppAnnotationService.enable_app_annotation(args, app_model.id)
elif action == "disable": elif action == "disable":
result = AppAnnotationService.disable_app_annotation(app_model.id) result = AppAnnotationService.disable_app_annotation(app_model.id)
else:
raise ValueError("Unsupported annotation reply action")
return result, 200 return result, 200





+ 4
- 2
api/controllers/service_api/dataset/dataset.py View File

from typing import Literal

from flask import request from flask import request
from flask_restful import marshal, marshal_with, reqparse from flask_restful import marshal, marshal_with, reqparse
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
class DocumentStatusApi(DatasetApiResource): class DocumentStatusApi(DatasetApiResource):
"""Resource for batch document status operations.""" """Resource for batch document status operations."""


def patch(self, tenant_id, dataset_id, action):
def patch(self, tenant_id, dataset_id, action: Literal["enable", "disable", "archive", "un_archive"]):
""" """
Batch update document status. Batch update document status.


Args: Args:
tenant_id: tenant id tenant_id: tenant id
dataset_id: dataset id dataset_id: dataset id
action: action to perform (enable, disable, archive, un_archive)
action: action to perform (Literal["enable", "disable", "archive", "un_archive"])


Returns: Returns:
dict: A dictionary with a key 'result' and a value 'success' dict: A dictionary with a key 'result' and a value 'success'

+ 3
- 1
api/controllers/service_api/dataset/metadata.py View File

from typing import Literal

from flask_login import current_user # type: ignore from flask_login import current_user # type: ignore
from flask_restful import marshal, reqparse from flask_restful import marshal, reqparse
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound


class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource): class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource):
@cloud_edition_billing_rate_limit_check("knowledge", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id, action):
def post(self, tenant_id, dataset_id, action: Literal["enable", "disable"]):
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str) dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None: if dataset is None:

+ 13
- 10
api/services/dataset_service.py View File

import time import time
import uuid import uuid
from collections import Counter from collections import Counter
from typing import Any, Optional
from typing import Any, Literal, Optional


from flask_login import current_user from flask_login import current_user
from sqlalchemy import func, select from sqlalchemy import func, select
RetrievalModel, RetrievalModel,
SegmentUpdateArgs, SegmentUpdateArgs,
) )
from services.errors.account import InvalidActionError, NoPermissionError
from services.errors.account import NoPermissionError
from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError
from services.errors.dataset import DatasetNameDuplicateError from services.errors.dataset import DatasetNameDuplicateError
from services.errors.document import DocumentIndexingError from services.errors.document import DocumentIndexingError
raise ValueError("Process rule segmentation max_tokens is invalid") raise ValueError("Process rule segmentation max_tokens is invalid")


@staticmethod @staticmethod
def batch_update_document_status(dataset: Dataset, document_ids: list[str], action: str, user):
def batch_update_document_status(
dataset: Dataset, document_ids: list[str], action: Literal["enable", "disable", "archive", "un_archive"], user
):
""" """
Batch update document status. Batch update document status.


Args: Args:
dataset (Dataset): The dataset object dataset (Dataset): The dataset object
document_ids (list[str]): List of document IDs to update document_ids (list[str]): List of document IDs to update
action (str): Action to perform (enable, disable, archive, un_archive)
action (Literal["enable", "disable", "archive", "un_archive"]): Action to perform
user: Current user performing the action user: Current user performing the action


Raises: Raises:
raise propagation_error raise propagation_error


@staticmethod @staticmethod
def _prepare_document_status_update(document, action: str, user):
"""
Prepare document status update information.
def _prepare_document_status_update(
document: Document, action: Literal["enable", "disable", "archive", "un_archive"], user
):
"""Prepare document status update information.


Args: Args:
document: Document object to update document: Document object to update
db.session.commit() db.session.commit()


@classmethod @classmethod
def update_segments_status(cls, segment_ids: list, action: str, dataset: Dataset, document: Document):
def update_segments_status(
cls, segment_ids: list, action: Literal["enable", "disable"], dataset: Dataset, document: Document
):
# Check if segment_ids is not empty to avoid WHERE false condition # Check if segment_ids is not empty to avoid WHERE false condition
if not segment_ids or len(segment_ids) == 0: if not segment_ids or len(segment_ids) == 0:
return return
db.session.commit() db.session.commit()


disable_segments_from_index_task.delay(real_deal_segment_ids, dataset.id, document.id) disable_segments_from_index_task.delay(real_deal_segment_ids, dataset.id, document.id)
else:
raise InvalidActionError()


@classmethod @classmethod
def create_child_chunk( def create_child_chunk(

+ 2
- 1
api/tasks/deal_dataset_vector_index_task.py View File

import logging import logging
import time import time
from typing import Literal


import click import click
from celery import shared_task # type: ignore from celery import shared_task # type: ignore




@shared_task(queue="dataset") @shared_task(queue="dataset")
def deal_dataset_vector_index_task(dataset_id: str, action: str):
def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "add", "update"]):
""" """
Async deal dataset from index Async deal dataset from index
:param dataset_id: dataset_id :param dataset_id: dataset_id

Loading…
Cancel
Save