Nelze vybrat více než 25 témat Téma musí začínat písmenem nebo číslem, může obsahovat pomlčky („-“) a může být dlouhé až 35 znaků.

s3_conn.py 6.5KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. #
  2. # Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  import functools
  import logging
  import time
  from io import BytesIO

  import boto3
  from botocore.config import Config
  from botocore.exceptions import ClientError

  from rag import settings
  from rag.utils import singleton
  @singleton
  class RAGFlowS3:
      """S3 object-storage connector built on a boto3 client.

      Connection settings are read from ``settings.S3``. The recognised keys
      are ``access_key``, ``secret_key``, ``region``, ``endpoint_url``,
      ``signature_version``, ``addressing_style``, ``bucket`` and
      ``prefix_path``; each one is optional and defaults to ``None``.

      When a default ``bucket`` is configured it replaces the bucket argument
      of the decorated methods, and the bucket name supplied by the caller is
      folded into the object key instead (see ``use_prefix_path``).
      """

      def __init__(self):
          """Load the S3 settings and open the client connection."""
          self.conn = None
          self.s3_config = settings.S3
          self.access_key = self.s3_config.get('access_key', None)
          self.secret_key = self.s3_config.get('secret_key', None)
          self.region = self.s3_config.get('region', None)
          self.endpoint_url = self.s3_config.get('endpoint_url', None)
          self.signature_version = self.s3_config.get('signature_version', None)
          self.addressing_style = self.s3_config.get('addressing_style', None)
          self.bucket = self.s3_config.get('bucket', None)
          self.prefix_path = self.s3_config.get('prefix_path', None)
          self.__open__()

      # NOTE: the two decorators below are applied inside the class body while
      # they are still ``staticmethod`` objects; calling a staticmethod object
      # directly is only supported on Python 3.10 and later.
      @staticmethod
      def use_default_bucket(method):
          """Decorate ``method`` so the configured default bucket wins.

          The wrapped method receives ``self.bucket`` when it is set, and the
          caller-supplied ``bucket`` otherwise.
          """
          @functools.wraps(method)
          def wrapper(self, bucket, *args, **kwargs):
              # If there is a default bucket, use the default bucket
              actual_bucket = self.bucket if self.bucket else bucket
              return method(self, actual_bucket, *args, **kwargs)
          return wrapper

      @staticmethod
      def use_prefix_path(method):
          """Decorate ``method`` so the object key is prefixed.

          The key becomes ``<prefix_path>/<bucket>/<fnm>`` when a prefix path
          is configured and ``<bucket>/<fnm>`` otherwise. The bucket argument
          itself is forwarded unchanged.
          """
          @functools.wraps(method)
          def wrapper(self, bucket, fnm, *args, **kwargs):
              # The bucket passed from the upstream call is used as the file
              # prefix. This is especially useful when the default bucket is
              # in use, because it keeps the callers' namespaces apart.
              if self.prefix_path:
                  fnm = f"{self.prefix_path}/{bucket}/{fnm}"
              else:
                  fnm = f"{bucket}/{fnm}"
              return method(self, bucket, fnm, *args, **kwargs)
          return wrapper

      def __open__(self):
          """(Re)create the boto3 S3 client from the loaded settings.

          Errors are logged rather than raised.
          """
          try:
              if self.conn:
                  self.__close__()
          except Exception:
              # Best effort: a stale client must not prevent reconnecting.
              logging.exception("Fail to close the previous S3 connection")

          try:
              s3_params = {
                  'aws_access_key_id': self.access_key,
                  'aws_secret_access_key': self.secret_key,
              }
              # Only forward options that actually carry a value.
              if self.region:
                  s3_params['region_name'] = self.region
              if self.endpoint_url:
                  s3_params['endpoint_url'] = self.endpoint_url

              # Build a single Config so that both options can be combined:
              # ``signature_version`` is a top-level Config argument, whereas
              # ``addressing_style`` belongs to the nested ``s3`` dictionary.
              config_kwargs = {}
              if self.signature_version:
                  config_kwargs['signature_version'] = self.signature_version
              if self.addressing_style:
                  config_kwargs['s3'] = {'addressing_style': self.addressing_style}
              if config_kwargs:
                  s3_params['config'] = Config(**config_kwargs)

              self.conn = boto3.client('s3', **s3_params)
          except Exception:
              logging.exception(f"Fail to connect at region {self.region} or endpoint {self.endpoint_url}")

      def __close__(self):
          """Drop the reference to the current client."""
          self.conn = None

      @use_default_bucket
      def bucket_exists(self, bucket):
          """Return ``True`` if ``head_bucket`` succeeds for ``bucket``.

          A ``ClientError`` is logged and reported as ``False``; other
          exceptions propagate.
          """
          try:
              logging.debug(f"head_bucket bucketname {bucket}")
              self.conn.head_bucket(Bucket=bucket)
              exists = True
          except ClientError:
              logging.exception(f"head_bucket error {bucket}")
              exists = False
          return exists

      def health(self):
          """Upload a small probe object to the default bucket.

          The bucket is created first when ``bucket_exists`` reports it as
          missing. Returns the result of ``upload_fileobj``; exceptions raised
          by the client are not caught here.
          """
          bucket = self.bucket
          fnm = "txtxtxtxt1"
          binary = b"_t@@@1"
          if self.prefix_path:
              fnm = f"{self.prefix_path}/{fnm}"
          if not self.bucket_exists(bucket):
              self.conn.create_bucket(Bucket=bucket)
              logging.debug(f"create bucket {bucket} ********")
          r = self.conn.upload_fileobj(BytesIO(binary), bucket, fnm)
          return r

      def get_properties(self, bucket, key):
          """Placeholder: always returns an empty dict."""
          return {}

      def list(self, bucket, dir, recursive=True):
          """Placeholder: always returns an empty list."""
          return []

      @use_prefix_path
      @use_default_bucket
      def put(self, bucket, fnm, binary):
          """Upload ``binary`` (bytes) under key ``fnm``.

          The bucket is created when ``bucket_exists`` reports it as missing.
          On failure the error is logged, the client is re-opened and ``None``
          is returned after a one-second pause.
          """
          logging.debug(f"bucket name {bucket}; filename :{fnm}:")
          for _ in range(1):
              try:
                  if not self.bucket_exists(bucket):
                      self.conn.create_bucket(Bucket=bucket)
                      logging.info(f"create bucket {bucket} ********")
                  r = self.conn.upload_fileobj(BytesIO(binary), bucket, fnm)
                  return r
              except Exception:
                  logging.exception(f"Fail put {bucket}/{fnm}")
                  self.__open__()
                  time.sleep(1)

      @use_prefix_path
      @use_default_bucket
      def rm(self, bucket, fnm):
          """Delete the object ``fnm``; failures are logged, not raised."""
          try:
              self.conn.delete_object(Bucket=bucket, Key=fnm)
          except Exception:
              logging.exception(f"Fail rm {bucket}/{fnm}")

      @use_prefix_path
      @use_default_bucket
      def get(self, bucket, fnm):
          """Return the content of object ``fnm`` as read from the response body.

          On failure the error is logged, the client is re-opened and ``None``
          is returned after a one-second pause.
          """
          for _ in range(1):
              try:
                  r = self.conn.get_object(Bucket=bucket, Key=fnm)
                  object_data = r['Body'].read()
                  return object_data
              except Exception:
                  logging.exception(f"fail get {bucket}/{fnm}")
                  self.__open__()
                  time.sleep(1)
          return None

      @use_prefix_path
      @use_default_bucket
      def obj_exist(self, bucket, fnm):
          """Return whether object ``fnm`` exists according to ``head_object``.

          Returns ``False`` when the client reports error code ``'404'``; any
          other ``ClientError`` is re-raised.
          """
          try:
              self.conn.head_object(Bucket=bucket, Key=fnm)
              return True
          except ClientError as e:
              if e.response['Error']['Code'] == '404':
                  return False
              raise

      @use_prefix_path
      @use_default_bucket
      def get_presigned_url(self, bucket, fnm, expires):
          """Return a presigned ``get_object`` URL valid for ``expires`` seconds.

          Up to ten attempts are made; after each failure the error is logged,
          the client is re-opened and the call pauses for one second. Returns
          ``None`` when every attempt fails.
          """
          for _ in range(10):
              try:
                  r = self.conn.generate_presigned_url('get_object',
                                                       Params={'Bucket': bucket,
                                                               'Key': fnm},
                                                       ExpiresIn=expires)
                  return r
              except Exception:
                  logging.exception(f"fail get url {bucket}/{fnm}")
                  self.__open__()
                  time.sleep(1)
          return None