From 81eb03d230bb9a738a737f46fde8ec24206d4aaa Mon Sep 17 00:00:00 2001 From: YngvarHuang <625452882@qq.com> Date: Mon, 15 Dec 2025 09:45:18 +0800 Subject: [PATCH] Support uploading encrypted files to object storage (#11837) (#11838) ### What problem does this PR solve? Support uploading encrypted files to object storage. ### Type of change - [x] New Feature (non-breaking change which adds functionality) --------- Co-authored-by: virgilwong --- common/crypto_utils.py | 374 +++++++++++++++++++++++++++++++++ common/settings.py | 22 +- docker/.env | 7 +- pyproject.toml | 1 + rag/utils/encrypted_storage.py | 266 +++++++++++++++++++++++ uv.lock | 2 + 6 files changed, 669 insertions(+), 3 deletions(-) create mode 100644 common/crypto_utils.py create mode 100644 rag/utils/encrypted_storage.py diff --git a/common/crypto_utils.py b/common/crypto_utils.py new file mode 100644 index 000000000..5dcbd2937 --- /dev/null +++ b/common/crypto_utils.py @@ -0,0 +1,374 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes +from cryptography.hazmat.primitives import padding +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC +from cryptography.hazmat.primitives import hashes + + +class BaseCrypto: + """Base class for cryptographic algorithms""" + + # Magic header to identify encrypted data + ENCRYPTED_MAGIC = b'RAGF' + + def __init__(self, key, iv=None, block_size=16, key_length=32, iv_length=16): + """ + Initialize cryptographic algorithm + + Args: + key: Encryption key + iv: Initialization vector, automatically generated if None + block_size: Block size + key_length: Key length + iv_length: Initialization vector length + """ + self.block_size = block_size + self.key_length = key_length + self.iv_length = iv_length + + # Normalize key + self.key = self._normalize_key(key) + self.iv = iv + + def _normalize_key(self, key): + """Normalize key length""" + if isinstance(key, str): + key = key.encode('utf-8') + + # Use PBKDF2 for key derivation to ensure correct key length + kdf = PBKDF2HMAC( + algorithm=hashes.SHA256(), + length=self.key_length, + salt=b"ragflow_crypto_salt", # Fixed salt to ensure consistent key derivation results + iterations=100000, + backend=default_backend() + ) + + return kdf.derive(key) + + def encrypt(self, data): + """ + Encrypt data (template method) + + Args: + data: Data to encrypt (bytes) + + Returns: + Encrypted data (bytes), format: magic_header + iv + encrypted_data + """ + # Generate random IV + iv = os.urandom(self.iv_length) if not self.iv else self.iv + + # Use PKCS7 padding + padder = padding.PKCS7(self.block_size * 8).padder() + padded_data = padder.update(data) + padder.finalize() + + # Delegate to subclass for specific encryption + ciphertext = self._encrypt(padded_data, iv) + + # Return Magic Header + IV + encrypted data + return self.ENCRYPTED_MAGIC + iv + ciphertext + + def decrypt(self, encrypted_data): + """ + Decrypt data (template method) + + Args: + encrypted_data: Encrypted data (bytes) + + Returns: + Decrypted data (bytes) + """ + # Check if data is encrypted by magic header + if not encrypted_data.startswith(self.ENCRYPTED_MAGIC): + # Not encrypted, return as-is + return encrypted_data + + # Remove magic header + encrypted_data = encrypted_data[len(self.ENCRYPTED_MAGIC):] + + # Separate IV and encrypted data + iv = encrypted_data[:self.iv_length] + ciphertext = encrypted_data[self.iv_length:] + + # Delegate to subclass for specific decryption + padded_data = self._decrypt(ciphertext, iv) + + # Remove padding + unpadder = padding.PKCS7(self.block_size * 8).unpadder() + data = unpadder.update(padded_data) + unpadder.finalize() + + return data + + def _encrypt(self, padded_data, iv): + """ + Encrypt padded data with specific algorithm + + Args: + padded_data: Padded data to encrypt + iv: Initialization vector + + Returns: + Encrypted data + """ + raise NotImplementedError("_encrypt method must be implemented by subclass") + + def _decrypt(self, ciphertext, iv): + """ + Decrypt ciphertext with specific algorithm + + Args: + ciphertext: Ciphertext to decrypt + iv: Initialization vector + + Returns: + Decrypted padded data + """ + raise NotImplementedError("_decrypt method must be implemented by subclass") + + +class AESCrypto(BaseCrypto): + """Base class for AES cryptographic algorithm""" + + def __init__(self, key, iv=None, key_length=32): + """ + Initialize AES cryptographic algorithm + + Args: + key: Encryption key + iv: Initialization vector, automatically generated if None + key_length: Key length (16 for AES-128, 32 for AES-256) + """ + super().__init__(key, iv, block_size=16, key_length=key_length, iv_length=16) + + def _encrypt(self, padded_data, iv): + """AES encryption implementation""" + # Create encryptor + cipher = Cipher( + algorithms.AES(self.key), + modes.CBC(iv), + backend=default_backend() + ) + encryptor = cipher.encryptor() + + # Encrypt data + return encryptor.update(padded_data) + encryptor.finalize() + + def _decrypt(self, ciphertext, iv): + """AES decryption implementation""" + # Create decryptor + cipher = Cipher( + algorithms.AES(self.key), + modes.CBC(iv), + backend=default_backend() + ) + decryptor = cipher.decryptor() + + # Decrypt data + return decryptor.update(ciphertext) + decryptor.finalize() + + +class AES128CBC(AESCrypto): + """AES-128-CBC cryptographic algorithm""" + + def __init__(self, key, iv=None): + """ + Initialize AES-128-CBC cryptographic algorithm + + Args: + key: Encryption key + iv: Initialization vector, automatically generated if None + """ + super().__init__(key, iv, key_length=16) + + +class AES256CBC(AESCrypto): + """AES-256-CBC cryptographic algorithm""" + + def __init__(self, key, iv=None): + """ + Initialize AES-256-CBC cryptographic algorithm + + Args: + key: Encryption key + iv: Initialization vector, automatically generated if None + """ + super().__init__(key, iv, key_length=32) + + +class SM4CBC(BaseCrypto): + """SM4-CBC cryptographic algorithm using cryptography library for better performance""" + + def __init__(self, key, iv=None): + """ + Initialize SM4-CBC cryptographic algorithm + + Args: + key: Encryption key + iv: Initialization vector, automatically generated if None + """ + super().__init__(key, iv, block_size=16, key_length=16, iv_length=16) + + def _encrypt(self, padded_data, iv): + """SM4 encryption implementation using cryptography library""" + # Create encryptor + cipher = Cipher( + algorithms.SM4(self.key), + modes.CBC(iv), + backend=default_backend() + ) + encryptor = cipher.encryptor() + + # Encrypt data + return encryptor.update(padded_data) + encryptor.finalize() + + def _decrypt(self, ciphertext, iv): + """SM4 decryption implementation using cryptography library""" + # Create decryptor + cipher = Cipher( + algorithms.SM4(self.key), + modes.CBC(iv), + backend=default_backend() + ) + decryptor = cipher.decryptor() + + # Decrypt data + return decryptor.update(ciphertext) + decryptor.finalize() + + +class CryptoUtil: + """Cryptographic utility class, using factory pattern to create cryptographic algorithm instances""" + + # Supported cryptographic algorithms mapping + SUPPORTED_ALGORITHMS = { + "aes-128-cbc": AES128CBC, + "aes-256-cbc": AES256CBC, + "sm4-cbc": SM4CBC + } + + def __init__(self, algorithm="aes-256-cbc", key=None, iv=None): + """ + Initialize cryptographic utility + + Args: + algorithm: Cryptographic algorithm, default is aes-256-cbc + key: Encryption key, uses RAGFLOW_CRYPTO_KEY environment variable if None + iv: Initialization vector, automatically generated if None + """ + if algorithm not in self.SUPPORTED_ALGORITHMS: + raise ValueError(f"Unsupported algorithm: {algorithm}") + + if not key: + raise ValueError("Encryption key not provided and RAGFLOW_CRYPTO_KEY environment variable not set") + + # Create cryptographic algorithm instance + self.algorithm_name = algorithm + self.crypto = self.SUPPORTED_ALGORITHMS[algorithm](key=key, iv=iv) + + def encrypt(self, data): + """ + Encrypt data + + Args: + data: Data to encrypt (bytes) + + Returns: + Encrypted data (bytes) + """ + # import time + # start_time = time.time() + encrypted = self.crypto.encrypt(data) + # end_time = time.time() + # logging.info(f"Encryption completed, data length: {len(data)} bytes, time: {(end_time - start_time)*1000:.2f} ms") + return encrypted + + def decrypt(self, encrypted_data): + """ + Decrypt data + + Args: + encrypted_data: Encrypted data (bytes) + + Returns: + Decrypted data (bytes) + """ + # import time + # start_time = time.time() + decrypted = self.crypto.decrypt(encrypted_data) + # end_time = time.time() + # logging.info(f"Decryption completed, data length: {len(encrypted_data)} bytes, time: {(end_time - start_time)*1000:.2f} ms") + return decrypted + + +# Test code +if __name__ == "__main__": + # Test AES encryption + crypto = CryptoUtil(algorithm="aes-256-cbc", key="test_key_123456") + test_data = b"Hello, RAGFlow! This is a test for encryption." + + encrypted = crypto.encrypt(test_data) + decrypted = crypto.decrypt(encrypted) + + print("AES Test:") + print(f"Original: {test_data}") + print(f"Encrypted: {encrypted}") + print(f"Decrypted: {decrypted}") + print(f"Success: {test_data == decrypted}") + print() + + # Test SM4 encryption + try: + crypto_sm4 = CryptoUtil(algorithm="sm4-cbc", key="test_key_123456") + encrypted_sm4 = crypto_sm4.encrypt(test_data) + decrypted_sm4 = crypto_sm4.decrypt(encrypted_sm4) + + print("SM4 Test:") + print(f"Original: {test_data}") + print(f"Encrypted: {encrypted_sm4}") + print(f"Decrypted: {decrypted_sm4}") + print(f"Success: {test_data == decrypted_sm4}") + except Exception as e: + print(f"SM4 Test Failed: {e}") + import traceback + traceback.print_exc() + + # Test with specific algorithm classes directly + print("\nDirect Algorithm Class Test:") + + # Test AES-128-CBC + aes128 = AES128CBC(key="test_key_123456") + encrypted_aes128 = aes128.encrypt(test_data) + decrypted_aes128 = aes128.decrypt(encrypted_aes128) + print(f"AES-128-CBC test: {'passed' if decrypted_aes128 == test_data else 'failed'}") + + # Test AES-256-CBC + aes256 = AES256CBC(key="test_key_123456") + encrypted_aes256 = aes256.encrypt(test_data) + decrypted_aes256 = aes256.decrypt(encrypted_aes256) + print(f"AES-256-CBC test: {'passed' if decrypted_aes256 == test_data else 'failed'}") + + # Test SM4-CBC + try: + sm4 = SM4CBC(key="test_key_123456") + encrypted_sm4 = sm4.encrypt(test_data) + decrypted_sm4 = sm4.decrypt(encrypted_sm4) + print(f"SM4-CBC test: {'passed' if decrypted_sm4 == test_data else 'failed'}") + except Exception as e: + print(f"SM4-CBC test failed: {e}") diff --git a/common/settings.py b/common/settings.py index 45dcdb618..a0385e716 100644 --- a/common/settings.py +++ b/common/settings.py @@ -269,7 +269,27 @@ def init_settings(): GCS = get_base_config("gcs", {}) global STORAGE_IMPL - STORAGE_IMPL = StorageFactory.create(Storage[STORAGE_IMPL_TYPE]) + storage_impl = StorageFactory.create(Storage[STORAGE_IMPL_TYPE]) + + # Define crypto settings + crypto_enabled = os.environ.get("RAGFLOW_CRYPTO_ENABLED", "false").lower() == "true" + + # Check if encryption is enabled + if crypto_enabled: + try: + from rag.utils.encrypted_storage import create_encrypted_storage + algorithm = os.environ.get("RAGFLOW_CRYPTO_ALGORITHM", "aes-256-cbc") + crypto_key = os.environ.get("RAGFLOW_CRYPTO_KEY") + + STORAGE_IMPL = create_encrypted_storage(storage_impl, + algorithm=algorithm, + key=crypto_key, + encryption_enabled=crypto_enabled) + except Exception as e: + logging.error(f"Failed to initialize encrypted storage: {e}") + STORAGE_IMPL = storage_impl + else: + STORAGE_IMPL = storage_impl global retriever, kg_retriever retriever = search.Dealer(docStoreConn) diff --git a/docker/.env b/docker/.env index 51d2cf73b..03f3e687f 100644 --- a/docker/.env +++ b/docker/.env @@ -240,7 +240,10 @@ MINERU_EXECUTABLE="$HOME/uv_tools/.venv/bin/mineru" # MINERU_DELETE_OUTPUT=0 # keep output directory # MINERU_BACKEND=pipeline # or another backend you prefer - - # pptx support DOTNET_SYSTEM_GLOBALIZATION_INVARIANT=1 + +# crypto utils +# RAGFLOW_CRYPTO_ENABLED=true +# RAGFLOW_CRYPTO_ALGORITHM=aes-256-cbc # one of aes-256-cbc, aes-128-cbc, sm4-cbc +# RAGFLOW_CRYPTO_KEY=ragflow-crypto-key diff --git a/pyproject.toml b/pyproject.toml index 137582ef9..53ddaa42f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -154,6 +154,7 @@ dependencies = [ "exceptiongroup>=1.3.0,<2.0.0", "ffmpeg-python>=0.2.0", "imageio-ffmpeg>=0.6.0", + "cryptography==46.0.3", "reportlab>=4.4.1", "jinja2>=3.1.0", "boxsdk>=10.1.0", diff --git a/rag/utils/encrypted_storage.py b/rag/utils/encrypted_storage.py new file mode 100644 index 000000000..19e199f4e --- /dev/null +++ b/rag/utils/encrypted_storage.py @@ -0,0 +1,266 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import logging +from common.crypto_utils import CryptoUtil +# from common.decorator import singleton + +class EncryptedStorageWrapper: + """Encrypted storage wrapper that wraps existing storage implementations to provide transparent encryption""" + + def __init__(self, storage_impl, algorithm="aes-256-cbc", key=None, iv=None): + """ + Initialize encrypted storage wrapper + + Args: + storage_impl: Original storage implementation instance + algorithm: Encryption algorithm, default is aes-256-cbc + key: Encryption key, uses RAGFLOW_CRYPTO_KEY environment variable if None + iv: Initialization vector, automatically generated if None + """ + self.storage_impl = storage_impl + self.crypto = CryptoUtil(algorithm=algorithm, key=key, iv=iv) + self.encryption_enabled = True + + # Check if storage implementation has required methods + # todo: Consider abstracting a storage base class to ensure these methods exist + required_methods = ["put", "get", "rm", "obj_exist", "health"] + for method in required_methods: + if not hasattr(storage_impl, method): + raise AttributeError(f"Storage implementation missing required method: {method}") + + logging.info(f"EncryptedStorageWrapper initialized with algorithm: {algorithm}") + + def put(self, bucket, fnm, binary, tenant_id=None): + """ + Encrypt and store data + + Args: + bucket: Bucket name + fnm: File name + binary: Original binary data + tenant_id: Tenant ID (optional) + + Returns: + Storage result + """ + if not self.encryption_enabled: + return self.storage_impl.put(bucket, fnm, binary, tenant_id) + + try: + encrypted_binary = self.crypto.encrypt(binary) + + return self.storage_impl.put(bucket, fnm, encrypted_binary, tenant_id) + except Exception as e: + logging.exception(f"Failed to encrypt and store data: {bucket}/{fnm}, error: {str(e)}") + raise + + def get(self, bucket, fnm, tenant_id=None): + """ + Retrieve and decrypt data + + Args: + bucket: Bucket name + fnm: File name + tenant_id: Tenant ID (optional) + + Returns: + Decrypted binary data + """ + try: + # Get encrypted data + encrypted_binary = self.storage_impl.get(bucket, fnm, tenant_id) + + if encrypted_binary is None: + return None + + if not self.encryption_enabled: + return encrypted_binary + + # Decrypt data + decrypted_binary = self.crypto.decrypt(encrypted_binary) + return decrypted_binary + + except Exception as e: + logging.exception(f"Failed to get and decrypt data: {bucket}/{fnm}, error: {str(e)}") + raise + + def rm(self, bucket, fnm, tenant_id=None): + """ + Delete data (same as original storage implementation, no decryption needed) + + Args: + bucket: Bucket name + fnm: File name + tenant_id: Tenant ID (optional) + + Returns: + Deletion result + """ + return self.storage_impl.rm(bucket, fnm, tenant_id) + + def obj_exist(self, bucket, fnm, tenant_id=None): + """ + Check if object exists (same as original storage implementation, no decryption needed) + + Args: + bucket: Bucket name + fnm: File name + tenant_id: Tenant ID (optional) + + Returns: + Whether the object exists + """ + return self.storage_impl.obj_exist(bucket, fnm, tenant_id) + + def health(self): + """ + Health check (uses the original storage implementation's method) + + Returns: + Health check result + """ + return self.storage_impl.health() + + def bucket_exists(self, bucket): + """ + Check if bucket exists (if the original storage implementation has this method) + + Args: + bucket: Bucket name + + Returns: + Whether the bucket exists + """ + if hasattr(self.storage_impl, "bucket_exists"): + return self.storage_impl.bucket_exists(bucket) + return False + + def get_presigned_url(self, bucket, fnm, expires, tenant_id=None): + """ + Get presigned URL (if the original storage implementation has this method) + + Args: + bucket: Bucket name + fnm: File name + expires: Expiration time + tenant_id: Tenant ID (optional) + + Returns: + Presigned URL + """ + if hasattr(self.storage_impl, "get_presigned_url"): + return self.storage_impl.get_presigned_url(bucket, fnm, expires, tenant_id) + return None + + def scan(self, bucket, fnm, tenant_id=None): + """ + Scan objects (if the original storage implementation has this method) + + Args: + bucket: Bucket name + fnm: File name prefix + tenant_id: Tenant ID (optional) + + Returns: + Scan results + """ + if hasattr(self.storage_impl, "scan"): + return self.storage_impl.scan(bucket, fnm, tenant_id) + return None + + def copy(self, src_bucket, src_path, dest_bucket, dest_path): + """ + Copy object (if the original storage implementation has this method) + + Args: + src_bucket: Source bucket name + src_path: Source file path + dest_bucket: Destination bucket name + dest_path: Destination file path + + Returns: + Copy result + """ + if hasattr(self.storage_impl, "copy"): + return self.storage_impl.copy(src_bucket, src_path, dest_bucket, dest_path) + return False + + def move(self, src_bucket, src_path, dest_bucket, dest_path): + """ + Move object (if the original storage implementation has this method) + + Args: + src_bucket: Source bucket name + src_path: Source file path + dest_bucket: Destination bucket name + dest_path: Destination file path + + Returns: + Move result + """ + if hasattr(self.storage_impl, "move"): + return self.storage_impl.move(src_bucket, src_path, dest_bucket, dest_path) + return False + + def remove_bucket(self, bucket): + """ + Remove bucket (if the original storage implementation has this method) + + Args: + bucket: Bucket name + + Returns: + Remove result + """ + if hasattr(self.storage_impl, "remove_bucket"): + return self.storage_impl.remove_bucket(bucket) + return False + + def enable_encryption(self): + """Enable encryption""" + self.encryption_enabled = True + logging.info("Encryption enabled") + + def disable_encryption(self): + """Disable encryption""" + self.encryption_enabled = False + logging.info("Encryption disabled") + +# Create singleton wrapper function +def create_encrypted_storage(storage_impl, algorithm=None, key=None, encryption_enabled=True): + """ + Create singleton instance of encrypted storage wrapper + + Args: + storage_impl: Original storage implementation instance + algorithm: Encryption algorithm, uses environment variable RAGFLOW_CRYPTO_ALGORITHM or default if None + key: Encryption key, uses environment variable RAGFLOW_CRYPTO_KEY if None + encryption_enabled: Whether to enable encryption functionality + + Returns: + Encrypted storage wrapper instance + """ + wrapper = EncryptedStorageWrapper(storage_impl, algorithm=algorithm, key=key) + + wrapper.encryption_enabled = encryption_enabled + + if encryption_enabled: + logging.info("Encryption enabled in storage wrapper") + else: + logging.info("Encryption disabled in storage wrapper") + + return wrapper diff --git a/uv.lock b/uv.lock index ef3429cd4..257149b89 100644 --- a/uv.lock +++ b/uv.lock @@ -6136,6 +6136,7 @@ dependencies = [ { name = "cn2an" }, { name = "cohere" }, { name = "crawl4ai" }, + { name = "cryptography" }, { name = "dashscope" }, { name = "datrie" }, { name = "debugpy" }, @@ -6309,6 +6310,7 @@ requires-dist = [ { name = "cn2an", specifier = "==0.5.22" }, { name = "cohere", specifier = "==5.6.2" }, { name = "crawl4ai", specifier = ">=0.4.0,<1.0.0" }, + { name = "cryptography", specifier = "==46.0.3" }, { name = "dashscope", specifier = "==1.20.11" }, { name = "datrie", specifier = ">=0.8.3,<0.9.0" }, { name = "debugpy", specifier = ">=1.8.13" },