選択できるのは25トピックまでです。 トピックは、先頭が英数字で、英数字とダッシュ('-')を使用した35文字以内のものにしてください。

s3_conn.py 6.4KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. #
  2. # Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import logging
  17. import boto3
  18. from botocore.exceptions import ClientError
  19. from botocore.config import Config
  20. import time
  21. from io import BytesIO
  22. from rag.utils import singleton
  23. from rag import settings
  @singleton
  class RAGFlowS3:
      """Object-storage client backed by an S3-compatible service via boto3.

      Connection settings are read from ``settings.S3``. Two optional settings
      change how callers' arguments are interpreted:

      * ``bucket``: when set, every bucket-addressed method ignores the bucket
        passed by the caller and uses this default bucket instead.
      * ``prefix_path``: when set, object keys are rewritten to
        ``<prefix_path>/<caller bucket>/<key>``, so the caller's bucket name
        becomes a key prefix (useful together with a default bucket).

      Decorated with ``singleton`` from ``rag.utils``; presumably this makes a
      single shared instance -- confirm against that helper.
      """

      def __init__(self):
          self.conn = None
          # settings.S3 is accessed with .get(), so it is used as a dict-like mapping.
          self.s3_config = settings.S3
          self.access_key = self.s3_config.get('access_key', None)
          self.secret_key = self.s3_config.get('secret_key', None)
          self.region = self.s3_config.get('region', None)
          self.endpoint_url = self.s3_config.get('endpoint_url', None)
          self.signature_version = self.s3_config.get('signature_version', None)
          self.addressing_style = self.s3_config.get('addressing_style', None)
          self.bucket = self.s3_config.get('bucket', None)
          self.prefix_path = self.s3_config.get('prefix_path', None)
          self.__open__()

      @staticmethod
      def use_default_bucket(method):
          """Decorator: substitute the configured default bucket for the caller's bucket.

          If ``self.bucket`` is truthy it replaces the ``bucket`` argument;
          otherwise the caller's bucket is passed through unchanged.
          """
          import functools

          @functools.wraps(method)
          def wrapper(self, bucket, *args, **kwargs):
              # If there is a default bucket, use the default bucket
              actual_bucket = self.bucket if self.bucket else bucket
              return method(self, actual_bucket, *args, **kwargs)

          return wrapper

      @staticmethod
      def use_prefix_path(method):
          """Decorator: rewrite the object key as ``<prefix_path>/<bucket>/<fnm>``.

          Applied outside ``use_default_bucket`` on the methods below, so it
          sees the bucket name the caller passed (before any default-bucket
          substitution) and folds it into the key.
          """
          import functools

          @functools.wraps(method)
          def wrapper(self, bucket, fnm, *args, **kwargs):
              # The bucket passed from the upstream call is used as a key
              # prefix; this is what keeps callers' data separated when all
              # of them are redirected into one default bucket.
              if self.prefix_path:
                  fnm = f"{self.prefix_path}/{bucket}/{fnm}"
              return method(self, bucket, fnm, *args, **kwargs)

          return wrapper

      def __open__(self):
          """(Re)create the boto3 S3 client from the stored settings.

          Any existing client is dropped first (best effort). Connection
          errors are logged, not raised; in that case ``self.conn`` is not
          replaced by a new client.
          """
          try:
              if self.conn:
                  self.__close__()
          except Exception:
              pass  # best effort: a failed close must not prevent reconnecting

          try:
              s3_params = {
                  'aws_access_key_id': self.access_key,
                  'aws_secret_access_key': self.secret_key,
              }
              # Only forward options that are present in the config so boto3
              # keeps its own defaults for everything else. (The region check
              # must test the *key*, not the configured region value.)
              if 'region' in self.s3_config:
                  s3_params['region_name'] = self.region
              if 'endpoint_url' in self.s3_config:
                  s3_params['endpoint_url'] = self.endpoint_url

              # signature_version is a top-level botocore Config option, while
              # addressing_style belongs in the S3-specific sub-dict. Build a
              # single Config so configuring one does not discard the other.
              config_kwargs = {}
              if 'signature_version' in self.s3_config:
                  config_kwargs['signature_version'] = self.signature_version
              if 'addressing_style' in self.s3_config:
                  config_kwargs['s3'] = {'addressing_style': self.addressing_style}
              if config_kwargs:
                  s3_params['config'] = Config(**config_kwargs)

              self.conn = boto3.client('s3', **s3_params)
          except Exception:
              logging.exception(f"Fail to connect at region {self.region} or endpoint {self.endpoint_url}")

      def __close__(self):
          """Drop the current client reference and reset ``self.conn`` to None."""
          del self.conn
          self.conn = None

      @use_default_bucket
      def bucket_exists(self, bucket) -> bool:
          """Return True if ``head_bucket`` succeeds, False on a ``ClientError``.

          The ClientError is logged with its traceback. Other exceptions
          propagate to the caller.
          """
          try:
              logging.debug(f"head_bucket bucketname {bucket}")
              self.conn.head_bucket(Bucket=bucket)
              exists = True
          except ClientError:
              logging.exception(f"head_bucket error {bucket}")
              exists = False
          return exists

      def health(self):
          """Write a small probe object to the default bucket, creating the bucket if needed.

          Returns whatever ``upload_fileobj`` returns; errors propagate.
          NOTE(review): uses ``self.bucket`` directly, so this appears to
          require a default bucket to be configured -- confirm the behavior
          when it is None.
          """
          bucket = self.bucket
          fnm = "txtxtxtxt1"
          # Only prefix_path is prepended here (no bucket segment), unlike the
          # use_prefix_path decorator.
          fnm, binary = f"{self.prefix_path}/{fnm}" if self.prefix_path else fnm, b"_t@@@1"
          if not self.bucket_exists(bucket):
              self.conn.create_bucket(Bucket=bucket)
              logging.debug(f"create bucket {bucket} ********")
          r = self.conn.upload_fileobj(BytesIO(binary), bucket, fnm)
          return r

      def get_properties(self, bucket, key):
          """Not implemented for S3; always returns an empty dict."""
          return {}

      def list(self, bucket, dir, recursive=True):
          """Not implemented for S3; always returns an empty list."""
          return []

      @use_prefix_path
      @use_default_bucket
      def put(self, bucket, fnm, binary):
          """Upload ``binary`` to ``bucket``/``fnm``, creating the bucket if it is missing.

          A single attempt is made. On failure the error is logged, the client
          is re-created for subsequent calls, and None is returned after a
          one-second pause.
          """
          logging.debug(f"bucket name {bucket}; filename :{fnm}:")
          for _ in range(1):
              try:
                  if not self.bucket_exists(bucket):
                      self.conn.create_bucket(Bucket=bucket)
                      logging.info(f"create bucket {bucket} ********")
                  r = self.conn.upload_fileobj(BytesIO(binary), bucket, fnm)
                  return r
              except Exception:
                  logging.exception(f"Fail put {bucket}/{fnm}")
                  self.__open__()
                  time.sleep(1)

      @use_prefix_path
      @use_default_bucket
      def rm(self, bucket, fnm):
          """Delete ``bucket``/``fnm``; failures are logged and not raised."""
          try:
              self.conn.delete_object(Bucket=bucket, Key=fnm)
          except Exception:
              logging.exception(f"Fail rm {bucket}/{fnm}")

      @use_prefix_path
      @use_default_bucket
      def get(self, bucket, fnm):
          """Return the object's body bytes, or None if the single attempt fails.

          On failure the error is logged and the client is re-created
          before returning.
          """
          for _ in range(1):
              try:
                  r = self.conn.get_object(Bucket=bucket, Key=fnm)
                  object_data = r['Body'].read()
                  return object_data
              except Exception:
                  logging.exception(f"fail get {bucket}/{fnm}")
                  self.__open__()
                  time.sleep(1)
          return

      @use_prefix_path
      @use_default_bucket
      def obj_exist(self, bucket, fnm):
          """Return True if ``head_object`` succeeds, False if it fails with error code '404'.

          Any other ``ClientError`` is re-raised.
          """
          try:
              if self.conn.head_object(Bucket=bucket, Key=fnm):
                  return True
          except ClientError as e:
              if e.response['Error']['Code'] == '404':
                  return False
              else:
                  raise

      @use_prefix_path
      @use_default_bucket
      def get_presigned_url(self, bucket, fnm, expires):
          """Return a presigned GET URL valid for ``expires`` seconds.

          Tries up to 10 times, re-creating the client and sleeping one second
          after each failure. Returns None if every attempt fails.
          """
          for _ in range(10):
              try:
                  r = self.conn.generate_presigned_url('get_object',
                                                       Params={'Bucket': bucket,
                                                               'Key': fnm},
                                                       ExpiresIn=expires)
                  return r
              except Exception:
                  logging.exception(f"fail get url {bucket}/{fnm}")
                  self.__open__()
                  time.sleep(1)
          return