Storage: Support S3 and Azure Blob as object storage backends for RAGFlow (#2278)

### What problem does this PR solve?

issue: https://github.com/infiniflow/ragflow/issues/2277

This PR adds AWS S3 and Azure Blob connectors (SAS-token and service-principal variants) plus a storage factory, so RAGFlow can use object stores other than MinIO as its document storage backend. Callers now go through a single `STORAGE_IMPL` object selected by the `STORAGE_IMPL` environment variable.

### Type of change

- [ ] Bug Fix (non-breaking change which fixes an issue)
- [x] New Feature (non-breaking change which adds functionality)
- [ ] Documentation Update
- [ ] Refactoring
- [ ] Performance Improvement
- [ ] Other (please describe):

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
Author: Fachuan Bai
Date: 2024-09-09 09:41:14 +08:00 (committed by GitHub)
Parent: e85fea31a8
Commit: 8dd3adc443

17 changed files with 395 additions and 38 deletions

rag/settings.py

@@ -24,6 +24,8 @@ RAG_CONF_PATH = os.path.join(get_project_base_directory(), "conf")
 SUBPROCESS_STD_LOG_NAME = "std.log"
 ES = get_base_config("es", {})
+AZURE = get_base_config("azure", {})
+S3 = get_base_config("s3", {})
 MINIO = decrypt_database_config(name="minio")
 try:
     REDIS = decrypt_database_config(name="redis")
@@ -43,6 +45,8 @@ LoggerFactory.LEVEL = 30
 es_logger = getLogger("es")
 minio_logger = getLogger("minio")
+s3_logger = getLogger("s3")
+azure_logger = getLogger("azure")
 cron_logger = getLogger("cron_logger")
 cron_logger.setLevel(20)
 chunk_logger = getLogger("chunk_logger")

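For reference, the two Azure connectors added below read their settings from this new `AZURE` section (the S3 connector is configured purely through environment variables). A minimal sketch of the expected shape of `settings.AZURE`, with key names taken from the connector code and placeholder values:

```python
# Illustrative only: the dict returned by get_base_config("azure", {}).
# Key names match what RAGFlowAzureSasBlob / RAGFlowAzureSpnBlob read below;
# the values here are placeholders, not working credentials.
AZURE_SAS_EXAMPLE = {
    "container_url": "https://<account>.blob.core.windows.net/<container>",
    "sas_token": "<sas-token>",
}

AZURE_SPN_EXAMPLE = {
    "account_url": "https://<account>.dfs.core.windows.net",
    "client_id": "<app-registration-client-id>",
    "secret": "<client-secret>",
    "tenant_id": "<tenant-id>",
    "container_name": "<container>",
}
```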

@@ -20,7 +20,7 @@ import traceback
 from api.db.db_models import close_connection
 from api.db.services.task_service import TaskService
 from rag.settings import cron_logger
-from rag.utils.minio_conn import MINIO
+from rag.utils.storage_factory import STORAGE_IMPL
 from rag.utils.redis_conn import REDIS_CONN
@@ -42,7 +42,7 @@ def main():
         try:
             key = "{}/{}".format(kb_id, loc)
             if REDIS_CONN.exist(key):continue
-            file_bin = MINIO.get(kb_id, loc)
+            file_bin = STORAGE_IMPL.get(kb_id, loc)
             REDIS_CONN.transaction(key, file_bin, 12 * 60)
             cron_logger.info("CACHE: {}".format(loc))
         except Exception as e:


@@ -29,7 +29,7 @@ from functools import partial
 from api.db.services.file2document_service import File2DocumentService
 from api.settings import retrievaler
 from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
-from rag.utils.minio_conn import MINIO
+from rag.utils.storage_factory import STORAGE_IMPL
 from api.db.db_models import close_connection
 from rag.settings import database_logger, SVR_QUEUE_NAME
 from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
@@ -138,7 +138,7 @@ def collect():
 def get_minio_binary(bucket, name):
-    return MINIO.get(bucket, name)
+    return STORAGE_IMPL.get(bucket, name)
 def build(row):
@@ -214,7 +214,7 @@
                 d["image"].save(output_buffer, format='JPEG')
             st = timer()
-            MINIO.put(row["kb_id"], d["_id"], output_buffer.getvalue())
+            STORAGE_IMPL.put(row["kb_id"], d["_id"], output_buffer.getvalue())
             el += timer() - st
         except Exception as e:
             cron_logger.error(str(e))

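The two hunks above only swap the hard-coded `MINIO` client for the factory-provided `STORAGE_IMPL`; every backend registered in the factory exposes the same small set of methods, so the callers are otherwise unchanged. A sketch of that implicit contract (not part of the PR, shown only to make the shared interface explicit):

```python
from typing import Protocol


class ObjectStorage(Protocol):
    """Implicit interface shared by RAGFlowMinio, RAGFlowS3 and the Azure connectors below."""

    def health(self): ...
    def put(self, bucket: str, fnm: str, binary: bytes): ...
    def get(self, bucket: str, fnm: str) -> bytes: ...
    def rm(self, bucket: str, fnm: str) -> None: ...
    def obj_exist(self, bucket: str, fnm: str) -> bool: ...
    def get_presigned_url(self, bucket: str, fnm: str, expires: int) -> str: ...
```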
rag/utils/azure_sas_conn.py (new file, 80 lines)

@@ -0,0 +1,80 @@
import os
import time
from io import BytesIO

from rag import settings
from rag.settings import azure_logger
from rag.utils import singleton

from azure.storage.blob import ContainerClient


@singleton
class RAGFlowAzureSasBlob(object):
    def __init__(self):
        self.conn = None
        self.container_url = os.getenv('CONTAINER_URL', settings.AZURE["container_url"])
        self.sas_token = os.getenv('SAS_TOKEN', settings.AZURE["sas_token"])
        self.__open__()

    def __open__(self):
        try:
            if self.conn:
                self.__close__()
        except Exception:
            pass

        try:
            self.conn = ContainerClient.from_container_url(self.container_url + "?" + self.sas_token)
        except Exception as e:
            azure_logger.error(
                "Fail to connect %s " % self.container_url + str(e))

    def __close__(self):
        del self.conn
        self.conn = None

    def health(self):
        bucket, fnm, binary = "txtxtxtxt1", "txtxtxtxt1", b"_t@@@1"
        return self.conn.upload_blob(name=fnm, data=BytesIO(binary), length=len(binary))

    def put(self, bucket, fnm, binary):
        for _ in range(3):
            try:
                return self.conn.upload_blob(name=fnm, data=BytesIO(binary), length=len(binary))
            except Exception as e:
                azure_logger.error(f"Fail put {bucket}/{fnm}: " + str(e))
                self.__open__()
                time.sleep(1)

    def rm(self, bucket, fnm):
        try:
            self.conn.delete_blob(fnm)
        except Exception as e:
            azure_logger.error(f"Fail rm {bucket}/{fnm}: " + str(e))

    def get(self, bucket, fnm):
        for _ in range(1):
            try:
                r = self.conn.download_blob(fnm)
                return r.read()
            except Exception as e:
                azure_logger.error(f"fail get {bucket}/{fnm}: " + str(e))
                self.__open__()
                time.sleep(1)
        return

    def obj_exist(self, bucket, fnm):
        try:
            return self.conn.get_blob_client(fnm).exists()
        except Exception as e:
            azure_logger.error(f"Fail exist {bucket}/{fnm}: " + str(e))
        return False

    def get_presigned_url(self, bucket, fnm, expires):
        # NOTE: ContainerClient has no get_presigned_url(); generating a SAS-based
        # download URL (e.g. via generate_blob_sas) would be needed for this to work.
        for _ in range(10):
            try:
                return self.conn.get_presigned_url("GET", bucket, fnm, expires)
            except Exception as e:
                azure_logger.error(f"fail get {bucket}/{fnm}: " + str(e))
                self.__open__()
                time.sleep(1)
        return

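The connector above (and the SPN and S3 connectors that follow) all use the same error-handling loop: try the operation a bounded number of times, reconnecting and sleeping between attempts. A condensed sketch of that pattern, factored into a helper purely for illustration (the PR keeps the loop inline in each method):

```python
import time


def _retry_with_reconnect(op, reopen, attempts=3, delay=1, logger=None):
    """Run op(); on failure, log the error, re-open the connection and retry.

    Mirrors the inline loops in put()/get() above; returns None if all attempts fail.
    """
    for _ in range(attempts):
        try:
            return op()
        except Exception as e:
            if logger:
                logger.error(str(e))
            reopen()
            time.sleep(delay)
    return None
```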
rag/utils/azure_spn_conn.py (new file, 90 lines)

@@ -0,0 +1,90 @@
import os
import time

from rag import settings
from rag.settings import azure_logger
from rag.utils import singleton

from azure.identity import ClientSecretCredential, AzureAuthorityHosts
from azure.storage.filedatalake import FileSystemClient


@singleton
class RAGFlowAzureSpnBlob(object):
    def __init__(self):
        self.conn = None
        self.account_url = os.getenv('ACCOUNT_URL', settings.AZURE["account_url"])
        self.client_id = os.getenv('CLIENT_ID', settings.AZURE["client_id"])
        self.secret = os.getenv('SECRET', settings.AZURE["secret"])
        self.tenant_id = os.getenv('TENANT_ID', settings.AZURE["tenant_id"])
        self.container_name = os.getenv('CONTAINER_NAME', settings.AZURE["container_name"])
        self.__open__()

    def __open__(self):
        try:
            if self.conn:
                self.__close__()
        except Exception:
            pass

        try:
            # The authority is hard-coded to the Azure China cloud.
            credentials = ClientSecretCredential(
                tenant_id=self.tenant_id, client_id=self.client_id,
                client_secret=self.secret, authority=AzureAuthorityHosts.AZURE_CHINA)
            self.conn = FileSystemClient(
                account_url=self.account_url,
                file_system_name=self.container_name,
                credential=credentials)
        except Exception as e:
            azure_logger.error(
                "Fail to connect %s " % self.account_url + str(e))

    def __close__(self):
        del self.conn
        self.conn = None

    def health(self):
        bucket, fnm, binary = "txtxtxtxt1", "txtxtxtxt1", b"_t@@@1"
        f = self.conn.create_file(fnm)
        f.append_data(binary, offset=0, length=len(binary))
        return f.flush_data(len(binary))

    def put(self, bucket, fnm, binary):
        for _ in range(3):
            try:
                f = self.conn.create_file(fnm)
                f.append_data(binary, offset=0, length=len(binary))
                return f.flush_data(len(binary))
            except Exception as e:
                azure_logger.error(f"Fail put {bucket}/{fnm}: " + str(e))
                self.__open__()
                time.sleep(1)

    def rm(self, bucket, fnm):
        try:
            self.conn.delete_file(fnm)
        except Exception as e:
            azure_logger.error(f"Fail rm {bucket}/{fnm}: " + str(e))

    def get(self, bucket, fnm):
        for _ in range(1):
            try:
                client = self.conn.get_file_client(fnm)
                r = client.download_file()
                return r.read()
            except Exception as e:
                azure_logger.error(f"fail get {bucket}/{fnm}: " + str(e))
                self.__open__()
                time.sleep(1)
        return

    def obj_exist(self, bucket, fnm):
        try:
            client = self.conn.get_file_client(fnm)
            return client.exists()
        except Exception as e:
            azure_logger.error(f"Fail exist {bucket}/{fnm}: " + str(e))
        return False

    def get_presigned_url(self, bucket, fnm, expires):
        # NOTE: FileSystemClient has no get_presigned_url(); a SAS-based URL
        # (e.g. via generate_file_sas) would be needed for this to work.
        for _ in range(10):
            try:
                return self.conn.get_presigned_url("GET", bucket, fnm, expires)
            except Exception as e:
                azure_logger.error(f"fail get {bucket}/{fnm}: " + str(e))
                self.__open__()
                time.sleep(1)
        return

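Note that the SPN connector hard-codes `AzureAuthorityHosts.AZURE_CHINA` when building the credential. A hedged sketch of the same credential for a tenant in the global Azure cloud, with placeholder values (this is an assumption about deployment, not part of the PR):

```python
from azure.identity import ClientSecretCredential, AzureAuthorityHosts

# Placeholder values; use AZURE_PUBLIC_CLOUD for tenants outside Azure China.
credential = ClientSecretCredential(
    tenant_id="<tenant-id>",
    client_id="<client-id>",
    client_secret="<client-secret>",
    authority=AzureAuthorityHosts.AZURE_PUBLIC_CLOUD,
)
```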
rag/utils/s3_conn.py (new file, 135 lines)

@@ -0,0 +1,135 @@
import os
import time
from io import BytesIO

import boto3
from botocore.client import Config
from botocore.exceptions import ClientError

from rag.settings import s3_logger
from rag.utils import singleton


@singleton
class RAGFlowS3(object):
    def __init__(self):
        self.conn = None
        self.endpoint = os.getenv('ENDPOINT', None)
        self.access_key = os.getenv('ACCESS_KEY', None)
        self.secret_key = os.getenv('SECRET_KEY', None)
        self.region = os.getenv('REGION', None)
        self.__open__()

    def __open__(self):
        try:
            if self.conn:
                self.__close__()
        except Exception:
            pass

        try:
            config = Config(
                s3={
                    'addressing_style': 'virtual'
                }
            )

            self.conn = boto3.client(
                's3',
                endpoint_url=self.endpoint,
                region_name=self.region,
                aws_access_key_id=self.access_key,
                aws_secret_access_key=self.secret_key,
                config=config
            )
        except Exception as e:
            s3_logger.error(
                "Fail to connect %s " % self.endpoint + str(e))

    def __close__(self):
        del self.conn
        self.conn = None

    def bucket_exists(self, bucket):
        try:
            s3_logger.error(f"head_bucket bucketname {bucket}")
            self.conn.head_bucket(Bucket=bucket)
            exists = True
        except ClientError as e:
            s3_logger.error(f"head_bucket error {bucket}: " + str(e))
            exists = False
        return exists

    def health(self):
        bucket, fnm, binary = "txtxtxtxt1", "txtxtxtxt1", b"_t@@@1"

        if not self.bucket_exists(bucket):
            self.conn.create_bucket(Bucket=bucket)
            s3_logger.error(f"create bucket {bucket} ********")

        r = self.conn.upload_fileobj(BytesIO(binary), bucket, fnm)
        return r

    def get_properties(self, bucket, key):
        # Not implemented for S3; returns empty metadata.
        return {}

    def list(self, bucket, dir, recursive=True):
        # Not implemented for S3; returns an empty listing.
        return []

    def put(self, bucket, fnm, binary):
        s3_logger.error(f"bucket name {bucket}; filename :{fnm}:")
        for _ in range(1):
            try:
                if not self.bucket_exists(bucket):
                    self.conn.create_bucket(Bucket=bucket)
                    s3_logger.error(f"create bucket {bucket} ********")

                r = self.conn.upload_fileobj(BytesIO(binary), bucket, fnm)
                return r
            except Exception as e:
                s3_logger.error(f"Fail put {bucket}/{fnm}: " + str(e))
                self.__open__()
                time.sleep(1)

    def rm(self, bucket, fnm):
        try:
            self.conn.delete_object(Bucket=bucket, Key=fnm)
        except Exception as e:
            s3_logger.error(f"Fail rm {bucket}/{fnm}: " + str(e))

    def get(self, bucket, fnm):
        for _ in range(1):
            try:
                r = self.conn.get_object(Bucket=bucket, Key=fnm)
                object_data = r['Body'].read()
                return object_data
            except Exception as e:
                s3_logger.error(f"fail get {bucket}/{fnm}: " + str(e))
                self.__open__()
                time.sleep(1)
        return

    def obj_exist(self, bucket, fnm):
        try:
            if self.conn.head_object(Bucket=bucket, Key=fnm):
                return True
        except ClientError as e:
            if e.response['Error']['Code'] == '404':
                return False
            else:
                raise

    def get_presigned_url(self, bucket, fnm, expires):
        for _ in range(10):
            try:
                r = self.conn.generate_presigned_url('get_object',
                                                     Params={'Bucket': bucket,
                                                             'Key': fnm},
                                                     ExpiresIn=expires)
                return r
            except Exception as e:
                s3_logger.error(f"fail get url {bucket}/{fnm}: " + str(e))
                self.__open__()
                time.sleep(1)
        return

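Unlike the Azure connectors, `RAGFlowS3` takes all of its configuration from environment variables (`ENDPOINT`, `ACCESS_KEY`, `SECRET_KEY`, `REGION`). A minimal sketch of wiring these up before the connector is first instantiated; the values are placeholders and `ENDPOINT` can point at any S3-compatible service:

```python
import os

# Placeholder values only; set these before rag.utils.s3_conn is imported.
os.environ.setdefault("ENDPOINT", "https://s3.us-east-1.amazonaws.com")
os.environ.setdefault("ACCESS_KEY", "<aws-access-key-id>")
os.environ.setdefault("SECRET_KEY", "<aws-secret-access-key>")
os.environ.setdefault("REGION", "us-east-1")
```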
rag/utils/storage_factory.py (new file, 30 lines)

@@ -0,0 +1,30 @@
import os
from enum import Enum

from rag.utils.azure_sas_conn import RAGFlowAzureSasBlob
from rag.utils.azure_spn_conn import RAGFlowAzureSpnBlob
from rag.utils.minio_conn import RAGFlowMinio
from rag.utils.s3_conn import RAGFlowS3


class Storage(Enum):
    MINIO = 1
    AZURE_SPN = 2
    AZURE_SAS = 3
    AWS_S3 = 4


class StorageFactory:
    storage_mapping = {
        Storage.MINIO: RAGFlowMinio,
        Storage.AZURE_SPN: RAGFlowAzureSpnBlob,
        Storage.AZURE_SAS: RAGFlowAzureSasBlob,
        Storage.AWS_S3: RAGFlowS3,
    }

    @classmethod
    def create(cls, storage: Storage):
        return cls.storage_mapping[storage]()


STORAGE_IMPL = StorageFactory.create(Storage[os.getenv('STORAGE_IMPL', 'MINIO')])
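
The backend is chosen once, at import time, from the `STORAGE_IMPL` environment variable (defaulting to `MINIO`), so the variable has to be set before `rag.utils.storage_factory` is first imported. A usage sketch with made-up bucket and object names:

```python
import os

os.environ["STORAGE_IMPL"] = "AWS_S3"  # or "MINIO", "AZURE_SAS", "AZURE_SPN"

from rag.utils.storage_factory import STORAGE_IMPL

# Hypothetical bucket/object names, only to show the call pattern.
STORAGE_IMPL.put("kb_demo", "doc_0001.jpg", b"...binary...")
data = STORAGE_IMPL.get("kb_demo", "doc_0001.jpg")
assert STORAGE_IMPL.obj_exist("kb_demo", "doc_0001.jpg")
```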