Files
ragflow/common/crypto_utils.py
YngvarHuang 81eb03d230 Support uploading encrypted files to object storage (#11837) (#11838)
### What problem does this PR solve?

Support uploading encrypted files to object storage.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: virgilwong <hyhvirgil@gmail.com>
2025-12-15 09:45:18 +08:00

375 lines
12 KiB
Python

#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives import padding
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from cryptography.hazmat.primitives import hashes
class BaseCrypto:
"""Base class for cryptographic algorithms"""
# Magic header to identify encrypted data
ENCRYPTED_MAGIC = b'RAGF'
def __init__(self, key, iv=None, block_size=16, key_length=32, iv_length=16):
"""
Initialize cryptographic algorithm
Args:
key: Encryption key
iv: Initialization vector, automatically generated if None
block_size: Block size
key_length: Key length
iv_length: Initialization vector length
"""
self.block_size = block_size
self.key_length = key_length
self.iv_length = iv_length
# Normalize key
self.key = self._normalize_key(key)
self.iv = iv
def _normalize_key(self, key):
"""Normalize key length"""
if isinstance(key, str):
key = key.encode('utf-8')
# Use PBKDF2 for key derivation to ensure correct key length
kdf = PBKDF2HMAC(
algorithm=hashes.SHA256(),
length=self.key_length,
salt=b"ragflow_crypto_salt", # Fixed salt to ensure consistent key derivation results
iterations=100000,
backend=default_backend()
)
return kdf.derive(key)
def encrypt(self, data):
"""
Encrypt data (template method)
Args:
data: Data to encrypt (bytes)
Returns:
Encrypted data (bytes), format: magic_header + iv + encrypted_data
"""
# Generate random IV
iv = os.urandom(self.iv_length) if not self.iv else self.iv
# Use PKCS7 padding
padder = padding.PKCS7(self.block_size * 8).padder()
padded_data = padder.update(data) + padder.finalize()
# Delegate to subclass for specific encryption
ciphertext = self._encrypt(padded_data, iv)
# Return Magic Header + IV + encrypted data
return self.ENCRYPTED_MAGIC + iv + ciphertext
def decrypt(self, encrypted_data):
"""
Decrypt data (template method)
Args:
encrypted_data: Encrypted data (bytes)
Returns:
Decrypted data (bytes)
"""
# Check if data is encrypted by magic header
if not encrypted_data.startswith(self.ENCRYPTED_MAGIC):
# Not encrypted, return as-is
return encrypted_data
# Remove magic header
encrypted_data = encrypted_data[len(self.ENCRYPTED_MAGIC):]
# Separate IV and encrypted data
iv = encrypted_data[:self.iv_length]
ciphertext = encrypted_data[self.iv_length:]
# Delegate to subclass for specific decryption
padded_data = self._decrypt(ciphertext, iv)
# Remove padding
unpadder = padding.PKCS7(self.block_size * 8).unpadder()
data = unpadder.update(padded_data) + unpadder.finalize()
return data
def _encrypt(self, padded_data, iv):
"""
Encrypt padded data with specific algorithm
Args:
padded_data: Padded data to encrypt
iv: Initialization vector
Returns:
Encrypted data
"""
raise NotImplementedError("_encrypt method must be implemented by subclass")
def _decrypt(self, ciphertext, iv):
"""
Decrypt ciphertext with specific algorithm
Args:
ciphertext: Ciphertext to decrypt
iv: Initialization vector
Returns:
Decrypted padded data
"""
raise NotImplementedError("_decrypt method must be implemented by subclass")
class AESCrypto(BaseCrypto):
"""Base class for AES cryptographic algorithm"""
def __init__(self, key, iv=None, key_length=32):
"""
Initialize AES cryptographic algorithm
Args:
key: Encryption key
iv: Initialization vector, automatically generated if None
key_length: Key length (16 for AES-128, 32 for AES-256)
"""
super().__init__(key, iv, block_size=16, key_length=key_length, iv_length=16)
def _encrypt(self, padded_data, iv):
"""AES encryption implementation"""
# Create encryptor
cipher = Cipher(
algorithms.AES(self.key),
modes.CBC(iv),
backend=default_backend()
)
encryptor = cipher.encryptor()
# Encrypt data
return encryptor.update(padded_data) + encryptor.finalize()
def _decrypt(self, ciphertext, iv):
"""AES decryption implementation"""
# Create decryptor
cipher = Cipher(
algorithms.AES(self.key),
modes.CBC(iv),
backend=default_backend()
)
decryptor = cipher.decryptor()
# Decrypt data
return decryptor.update(ciphertext) + decryptor.finalize()
class AES128CBC(AESCrypto):
"""AES-128-CBC cryptographic algorithm"""
def __init__(self, key, iv=None):
"""
Initialize AES-128-CBC cryptographic algorithm
Args:
key: Encryption key
iv: Initialization vector, automatically generated if None
"""
super().__init__(key, iv, key_length=16)
class AES256CBC(AESCrypto):
"""AES-256-CBC cryptographic algorithm"""
def __init__(self, key, iv=None):
"""
Initialize AES-256-CBC cryptographic algorithm
Args:
key: Encryption key
iv: Initialization vector, automatically generated if None
"""
super().__init__(key, iv, key_length=32)
class SM4CBC(BaseCrypto):
"""SM4-CBC cryptographic algorithm using cryptography library for better performance"""
def __init__(self, key, iv=None):
"""
Initialize SM4-CBC cryptographic algorithm
Args:
key: Encryption key
iv: Initialization vector, automatically generated if None
"""
super().__init__(key, iv, block_size=16, key_length=16, iv_length=16)
def _encrypt(self, padded_data, iv):
"""SM4 encryption implementation using cryptography library"""
# Create encryptor
cipher = Cipher(
algorithms.SM4(self.key),
modes.CBC(iv),
backend=default_backend()
)
encryptor = cipher.encryptor()
# Encrypt data
return encryptor.update(padded_data) + encryptor.finalize()
def _decrypt(self, ciphertext, iv):
"""SM4 decryption implementation using cryptography library"""
# Create decryptor
cipher = Cipher(
algorithms.SM4(self.key),
modes.CBC(iv),
backend=default_backend()
)
decryptor = cipher.decryptor()
# Decrypt data
return decryptor.update(ciphertext) + decryptor.finalize()
class CryptoUtil:
"""Cryptographic utility class, using factory pattern to create cryptographic algorithm instances"""
# Supported cryptographic algorithms mapping
SUPPORTED_ALGORITHMS = {
"aes-128-cbc": AES128CBC,
"aes-256-cbc": AES256CBC,
"sm4-cbc": SM4CBC
}
def __init__(self, algorithm="aes-256-cbc", key=None, iv=None):
"""
Initialize cryptographic utility
Args:
algorithm: Cryptographic algorithm, default is aes-256-cbc
key: Encryption key, uses RAGFLOW_CRYPTO_KEY environment variable if None
iv: Initialization vector, automatically generated if None
"""
if algorithm not in self.SUPPORTED_ALGORITHMS:
raise ValueError(f"Unsupported algorithm: {algorithm}")
if not key:
raise ValueError("Encryption key not provided and RAGFLOW_CRYPTO_KEY environment variable not set")
# Create cryptographic algorithm instance
self.algorithm_name = algorithm
self.crypto = self.SUPPORTED_ALGORITHMS[algorithm](key=key, iv=iv)
def encrypt(self, data):
"""
Encrypt data
Args:
data: Data to encrypt (bytes)
Returns:
Encrypted data (bytes)
"""
# import time
# start_time = time.time()
encrypted = self.crypto.encrypt(data)
# end_time = time.time()
# logging.info(f"Encryption completed, data length: {len(data)} bytes, time: {(end_time - start_time)*1000:.2f} ms")
return encrypted
def decrypt(self, encrypted_data):
"""
Decrypt data
Args:
encrypted_data: Encrypted data (bytes)
Returns:
Decrypted data (bytes)
"""
# import time
# start_time = time.time()
decrypted = self.crypto.decrypt(encrypted_data)
# end_time = time.time()
# logging.info(f"Decryption completed, data length: {len(encrypted_data)} bytes, time: {(end_time - start_time)*1000:.2f} ms")
return decrypted
# Test code
if __name__ == "__main__":
# Test AES encryption
crypto = CryptoUtil(algorithm="aes-256-cbc", key="test_key_123456")
test_data = b"Hello, RAGFlow! This is a test for encryption."
encrypted = crypto.encrypt(test_data)
decrypted = crypto.decrypt(encrypted)
print("AES Test:")
print(f"Original: {test_data}")
print(f"Encrypted: {encrypted}")
print(f"Decrypted: {decrypted}")
print(f"Success: {test_data == decrypted}")
print()
# Test SM4 encryption
try:
crypto_sm4 = CryptoUtil(algorithm="sm4-cbc", key="test_key_123456")
encrypted_sm4 = crypto_sm4.encrypt(test_data)
decrypted_sm4 = crypto_sm4.decrypt(encrypted_sm4)
print("SM4 Test:")
print(f"Original: {test_data}")
print(f"Encrypted: {encrypted_sm4}")
print(f"Decrypted: {decrypted_sm4}")
print(f"Success: {test_data == decrypted_sm4}")
except Exception as e:
print(f"SM4 Test Failed: {e}")
import traceback
traceback.print_exc()
# Test with specific algorithm classes directly
print("\nDirect Algorithm Class Test:")
# Test AES-128-CBC
aes128 = AES128CBC(key="test_key_123456")
encrypted_aes128 = aes128.encrypt(test_data)
decrypted_aes128 = aes128.decrypt(encrypted_aes128)
print(f"AES-128-CBC test: {'passed' if decrypted_aes128 == test_data else 'failed'}")
# Test AES-256-CBC
aes256 = AES256CBC(key="test_key_123456")
encrypted_aes256 = aes256.encrypt(test_data)
decrypted_aes256 = aes256.decrypt(encrypted_aes256)
print(f"AES-256-CBC test: {'passed' if decrypted_aes256 == test_data else 'failed'}")
# Test SM4-CBC
try:
sm4 = SM4CBC(key="test_key_123456")
encrypted_sm4 = sm4.encrypt(test_data)
decrypted_sm4 = sm4.decrypt(encrypted_sm4)
print(f"SM4-CBC test: {'passed' if decrypted_sm4 == test_data else 'failed'}")
except Exception as e:
print(f"SM4-CBC test failed: {e}")