### What problem does this PR solve? ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue)tags/v0.4.0
| @@ -45,7 +45,7 @@ def convert(): | |||
| for file_id in file_ids: | |||
| e, file = FileService.get_by_id(file_id) | |||
| file_ids_list = [file_id] | |||
| if file.type == FileType.FOLDER: | |||
| if file.type == FileType.FOLDER.value: | |||
| file_ids_list = FileService.get_all_innermost_file_ids(file_id, []) | |||
| for id in file_ids_list: | |||
| informs = File2DocumentService.get_by_file_id(id) | |||
| @@ -64,7 +64,7 @@ def upload(): | |||
| return get_data_error_result( | |||
| retmsg="Can't find this folder!") | |||
| MAX_FILE_NUM_PER_USER = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0)) | |||
| if MAX_FILE_NUM_PER_USER > 0 and DocumentService.get_doc_count(kb.tenant_id) >= MAX_FILE_NUM_PER_USER: | |||
| if MAX_FILE_NUM_PER_USER > 0 and DocumentService.get_doc_count(current_user.id) >= MAX_FILE_NUM_PER_USER: | |||
| return get_data_error_result( | |||
| retmsg="Exceed the maximum file number of a free user!") | |||
| @@ -143,9 +143,9 @@ def create(): | |||
| retmsg="Duplicated folder name in the same folder.") | |||
| if input_file_type == FileType.FOLDER.value: | |||
| file_type = FileType.FOLDER | |||
| file_type = FileType.FOLDER.value | |||
| else: | |||
| file_type = FileType.VIRTUAL | |||
| file_type = FileType.VIRTUAL.value | |||
| file = FileService.insert({ | |||
| "id": get_uuid(), | |||
| @@ -251,7 +251,7 @@ def rm(): | |||
| if not file.tenant_id: | |||
| return get_data_error_result(retmsg="Tenant not found!") | |||
| if file.type == FileType.FOLDER: | |||
| if file.type == FileType.FOLDER.value: | |||
| file_id_list = FileService.get_all_innermost_file_ids(file_id, []) | |||
| for inner_file_id in file_id_list: | |||
| e, file = FileService.get_by_id(inner_file_id) | |||
| @@ -24,7 +24,7 @@ from api.db.db_models import TenantLLM | |||
| from api.db.services.llm_service import TenantLLMService, LLMService | |||
| from api.utils.api_utils import server_error_response, validate_request | |||
| from api.utils import get_uuid, get_format_time, decrypt, download_img, current_timestamp, datetime_format | |||
| from api.db import UserTenantRole, LLMType | |||
| from api.db import UserTenantRole, LLMType, FileType | |||
| from api.settings import RetCode, GITHUB_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS, API_KEY, \ | |||
| LLM_FACTORY, LLM_BASE_URL | |||
| from api.db.services.user_service import UserService, TenantService, UserTenantService | |||
| @@ -229,7 +229,7 @@ def user_register(user_id, user): | |||
| "tenant_id": user_id, | |||
| "created_by": user_id, | |||
| "name": "/", | |||
| "type": FileType.FOLDER, | |||
| "type": FileType.FOLDER.value, | |||
| "size": 0, | |||
| "location": "", | |||
| } | |||
| @@ -120,7 +120,7 @@ class FileService(CommonService): | |||
| "name": name[count], | |||
| "location": "", | |||
| "size": 0, | |||
| "type": FileType.FOLDER | |||
| "type": FileType.FOLDER.value | |||
| }) | |||
| return cls.create_folder(file, file.id, name, count + 1) | |||
| @@ -138,7 +138,23 @@ class FileService(CommonService): | |||
| def get_root_folder(cls, tenant_id): | |||
| file = cls.model.select().where(cls.model.tenant_id == tenant_id and | |||
| cls.model.parent_id == cls.model.id) | |||
| e, file = cls.get_by_id(file[0].id) | |||
| if not file: | |||
| file_id = get_uuid() | |||
| file = { | |||
| "id": file_id, | |||
| "parent_id": file_id, | |||
| "tenant_id": tenant_id, | |||
| "created_by": tenant_id, | |||
| "name": "/", | |||
| "type": FileType.FOLDER.value, | |||
| "size": 0, | |||
| "location": "", | |||
| } | |||
| cls.save(**file) | |||
| else: | |||
| file_id = file[0].id | |||
| e, file = cls.get_by_id(file_id) | |||
| if not e: | |||
| raise RuntimeError("Database error (File retrieval)!") | |||
| return file | |||
| @@ -214,12 +230,14 @@ class FileService(CommonService): | |||
| @DB.connection_context() | |||
| def get_folder_size(cls, folder_id): | |||
| size = 0 | |||
| def dfs(parent_id): | |||
| nonlocal size | |||
| for f in cls.model.select(*[cls.model.id, cls.model.size, cls.model.type]).where(cls.model.parent_id == parent_id): | |||
| for f in cls.model.select(*[cls.model.id, cls.model.size, cls.model.type]).where( | |||
| cls.model.parent_id == parent_id, cls.model.id != parent_id): | |||
| size += f.size | |||
| if f.type == FileType.FOLDER.value: | |||
| dfs(f.id) | |||
| dfs(folder_id) | |||
| return size | |||
| return size | |||