You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
#
#  Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
import logging
import time
from functools import wraps
from io import BytesIO

import boto3
from botocore.config import Config
from botocore.exceptions import ClientError

from rag import settings
from rag.utils import singleton
  24. @singleton
  25. class RAGFlowS3:
  26. def __init__(self):
  27. self.conn = None
  28. self.s3_config = settings.S3
  29. self.access_key = self.s3_config.get('access_key', None)
  30. self.secret_key = self.s3_config.get('secret_key', None)
  31. self.session_token = self.s3_config.get('session_token', None)
  32. self.region_name = self.s3_config.get('region_name', None)
  33. self.endpoint_url = self.s3_config.get('endpoint_url', None)
  34. self.signature_version = self.s3_config.get('signature_version', None)
  35. self.addressing_style = self.s3_config.get('addressing_style', None)
  36. self.bucket = self.s3_config.get('bucket', None)
  37. self.prefix_path = self.s3_config.get('prefix_path', None)
  38. self.__open__()
  39. @staticmethod
  40. def use_default_bucket(method):
  41. def wrapper(self, bucket, *args, **kwargs):
  42. # If there is a default bucket, use the default bucket
  43. actual_bucket = self.bucket if self.bucket else bucket
  44. return method(self, actual_bucket, *args, **kwargs)
  45. return wrapper
  46. @staticmethod
  47. def use_prefix_path(method):
  48. def wrapper(self, bucket, fnm, *args, **kwargs):
  49. # If the prefix path is set, use the prefix path.
  50. # The bucket passed from the upstream call is
  51. # used as the file prefix. This is especially useful when you're using the default bucket
  52. if self.prefix_path:
  53. fnm = f"{self.prefix_path}/{bucket}/{fnm}"
  54. return method(self, bucket, fnm, *args, **kwargs)
  55. return wrapper
  56. def __open__(self):
  57. try:
  58. if self.conn:
  59. self.__close__()
  60. except Exception:
  61. pass
  62. try:
  63. s3_params = {}
  64. config_kwargs = {}
  65. # if not set ak/sk, boto3 s3 client would try several ways to do the authentication
  66. # see doc: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html#configuring-credentials
  67. if self.access_key and self.secret_key:
  68. s3_params = {
  69. 'aws_access_key_id': self.access_key,
  70. 'aws_secret_access_key': self.secret_key,
  71. 'aws_session_token': self.session_token,
  72. }
  73. if self.region_name:
  74. s3_params['region_name'] = self.region_name
  75. if self.endpoint_url:
  76. s3_params['endpoint_url'] = self.endpoint_url
  77. if self.signature_version:
  78. s3_params['signature_version'] = self.signature_version
  79. if self.addressing_style:
  80. s3_params['addressing_style'] = self.addressing_style
  81. if config_kwargs:
  82. s3_params['config'] = Config(**config_kwargs)
  83. self.conn = [boto3.client('s3', **s3_params)]
  84. except Exception:
  85. logging.exception(f"Fail to connect at region {self.region_name} or endpoint {self.endpoint_url}")
  86. def __close__(self):
  87. del self.conn[0]
  88. self.conn = None
  89. @use_default_bucket
  90. def bucket_exists(self, bucket, *args, **kwargs):
  91. try:
  92. logging.debug(f"head_bucket bucketname {bucket}")
  93. self.conn[0].head_bucket(Bucket=bucket)
  94. exists = True
  95. except ClientError:
  96. logging.exception(f"head_bucket error {bucket}")
  97. exists = False
  98. return exists
  99. def health(self):
  100. bucket = self.bucket
  101. fnm = "txtxtxtxt1"
  102. fnm, binary = f"{self.prefix_path}/{fnm}" if self.prefix_path else fnm, b"_t@@@1"
  103. if not self.bucket_exists(bucket):
  104. self.conn[0].create_bucket(Bucket=bucket)
  105. logging.debug(f"create bucket {bucket} ********")
  106. r = self.conn[0].upload_fileobj(BytesIO(binary), bucket, fnm)
  107. return r
  108. def get_properties(self, bucket, key):
  109. return {}
  110. def list(self, bucket, dir, recursive=True):
  111. return []
  112. @use_prefix_path
  113. @use_default_bucket
  114. def put(self, bucket, fnm, binary, *args, **kwargs):
  115. logging.debug(f"bucket name {bucket}; filename :{fnm}:")
  116. for _ in range(1):
  117. try:
  118. if not self.bucket_exists(bucket):
  119. self.conn[0].create_bucket(Bucket=bucket)
  120. logging.info(f"create bucket {bucket} ********")
  121. r = self.conn[0].upload_fileobj(BytesIO(binary), bucket, fnm)
  122. return r
  123. except Exception:
  124. logging.exception(f"Fail put {bucket}/{fnm}")
  125. self.__open__()
  126. time.sleep(1)
  127. @use_prefix_path
  128. @use_default_bucket
  129. def rm(self, bucket, fnm, *args, **kwargs):
  130. try:
  131. self.conn[0].delete_object(Bucket=bucket, Key=fnm)
  132. except Exception:
  133. logging.exception(f"Fail rm {bucket}/{fnm}")
  134. @use_prefix_path
  135. @use_default_bucket
  136. def get(self, bucket, fnm, *args, **kwargs):
  137. for _ in range(1):
  138. try:
  139. r = self.conn[0].get_object(Bucket=bucket, Key=fnm)
  140. object_data = r['Body'].read()
  141. return object_data
  142. except Exception:
  143. logging.exception(f"fail get {bucket}/{fnm}")
  144. self.__open__()
  145. time.sleep(1)
  146. return
  147. @use_prefix_path
  148. @use_default_bucket
  149. def obj_exist(self, bucket, fnm, *args, **kwargs):
  150. try:
  151. if self.conn[0].head_object(Bucket=bucket, Key=fnm):
  152. return True
  153. except ClientError as e:
  154. if e.response['Error']['Code'] == '404':
  155. return False
  156. else:
  157. raise
  158. @use_prefix_path
  159. @use_default_bucket
  160. def get_presigned_url(self, bucket, fnm, expires, *args, **kwargs):
  161. for _ in range(10):
  162. try:
  163. r = self.conn[0].generate_presigned_url('get_object',
  164. Params={'Bucket': bucket,
  165. 'Key': fnm},
  166. ExpiresIn=expires)
  167. return r
  168. except Exception:
  169. logging.exception(f"fail get url {bucket}/{fnm}")
  170. self.__open__()
  171. time.sleep(1)
  172. return