You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

oss_conn.py 5.3KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. #
  2. # Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
import functools
import logging
import time
from io import BytesIO

import boto3
from botocore.config import Config
from botocore.exceptions import ClientError

from rag import settings
from rag.utils import singleton
  23. @singleton
  24. class RAGFlowOSS(object):
  25. def __init__(self):
  26. self.conn = None
  27. self.oss_config = settings.OSS
  28. self.access_key = self.oss_config.get('access_key', None)
  29. self.secret_key = self.oss_config.get('secret_key', None)
  30. self.endpoint_url = self.oss_config.get('endpoint_url', None)
  31. self.region = self.oss_config.get('region', None)
  32. self.bucket = self.oss_config.get('bucket', None)
  33. self.__open__()
  34. @staticmethod
  35. def use_default_bucket(method):
  36. def wrapper(self, bucket, *args, **kwargs):
  37. # If there is a default bucket, use the default bucket
  38. actual_bucket = self.bucket if self.bucket else bucket
  39. return method(self, actual_bucket, *args, **kwargs)
  40. return wrapper
  41. def __open__(self):
  42. try:
  43. if self.conn:
  44. self.__close__()
  45. except Exception:
  46. pass
  47. try:
  48. # Reference:https://help.aliyun.com/zh/oss/developer-reference/use-amazon-s3-sdks-to-access-oss
  49. self.conn = boto3.client(
  50. 's3',
  51. region_name=self.region,
  52. aws_access_key_id=self.access_key,
  53. aws_secret_access_key=self.secret_key,
  54. endpoint_url=self.endpoint_url,
  55. config=Config(s3={"addressing_style": "virtual"}, signature_version='v4')
  56. )
  57. except Exception:
  58. logging.exception(f"Fail to connect at region {self.region}")
  59. def __close__(self):
  60. del self.conn
  61. self.conn = None
  62. @use_default_bucket
  63. def bucket_exists(self, bucket):
  64. try:
  65. logging.debug(f"head_bucket bucketname {bucket}")
  66. self.conn.head_bucket(Bucket=bucket)
  67. exists = True
  68. except ClientError:
  69. logging.exception(f"head_bucket error {bucket}")
  70. exists = False
  71. return exists
  72. def health(self):
  73. bucket, fnm, binary = "txtxtxtxt1", "txtxtxtxt1", b"_t@@@1"
  74. if not self.bucket_exists(bucket):
  75. self.conn.create_bucket(Bucket=bucket)
  76. logging.debug(f"create bucket {bucket} ********")
  77. r = self.conn.upload_fileobj(BytesIO(binary), bucket, fnm)
  78. return r
  79. def get_properties(self, bucket, key):
  80. return {}
  81. def list(self, bucket, dir, recursive=True):
  82. return []
  83. @use_default_bucket
  84. def put(self, bucket, fnm, binary):
  85. logging.debug(f"bucket name {bucket}; filename :{fnm}:")
  86. for _ in range(1):
  87. try:
  88. if not self.bucket_exists(bucket):
  89. self.conn.create_bucket(Bucket=bucket)
  90. logging.info(f"create bucket {bucket} ********")
  91. r = self.conn.upload_fileobj(BytesIO(binary), bucket, fnm)
  92. return r
  93. except Exception:
  94. logging.exception(f"Fail put {bucket}/{fnm}")
  95. self.__open__()
  96. time.sleep(1)
  97. @use_default_bucket
  98. def rm(self, bucket, fnm):
  99. try:
  100. self.conn.delete_object(Bucket=bucket, Key=fnm)
  101. except Exception:
  102. logging.exception(f"Fail rm {bucket}/{fnm}")
  103. @use_default_bucket
  104. def get(self, bucket, fnm):
  105. for _ in range(1):
  106. try:
  107. r = self.conn.get_object(Bucket=bucket, Key=fnm)
  108. object_data = r['Body'].read()
  109. return object_data
  110. except Exception:
  111. logging.exception(f"fail get {bucket}/{fnm}")
  112. self.__open__()
  113. time.sleep(1)
  114. return
  115. @use_default_bucket
  116. def obj_exist(self, bucket, fnm):
  117. try:
  118. if self.conn.head_object(Bucket=bucket, Key=fnm):
  119. return True
  120. except ClientError as e:
  121. if e.response['Error']['Code'] == '404':
  122. return False
  123. else:
  124. raise
  125. @use_default_bucket
  126. def get_presigned_url(self, bucket, fnm, expires):
  127. for _ in range(10):
  128. try:
  129. r = self.conn.generate_presigned_url('get_object',
  130. Params={'Bucket': bucket,
  131. 'Key': fnm},
  132. ExpiresIn=expires)
  133. return r
  134. except Exception:
  135. logging.exception(f"fail get url {bucket}/{fnm}")
  136. self.__open__()
  137. time.sleep(1)
  138. return