mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-19 03:56:42 +08:00
### What problem does this PR solve? Support uploading encrypted files to object storage. ### Type of change - [x] New Feature (non-breaking change which adds functionality) --------- Co-authored-by: virgilwong <hyhvirgil@gmail.com>
375 lines
12 KiB
Python
375 lines
12 KiB
Python
#
|
|
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
|
|
import os
|
|
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
|
from cryptography.hazmat.primitives import padding
|
|
from cryptography.hazmat.backends import default_backend
|
|
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
|
from cryptography.hazmat.primitives import hashes
|
|
|
|
|
|
class BaseCrypto:
    """Base class for cryptographic algorithms.

    Implements the template shared by every cipher in this module: key
    derivation (PBKDF2-HMAC-SHA256), IV selection, PKCS7 padding and the
    serialized layout ``ENCRYPTED_MAGIC + iv + ciphertext``. Subclasses only
    provide the raw block-cipher step in ``_encrypt`` / ``_decrypt``.

    NOTE(review): the layout carries no MAC/authentication tag, so tampering
    is only detected indirectly (e.g. as a padding error on decrypt).
    """

    # Magic header to identify encrypted data
    ENCRYPTED_MAGIC = b'RAGF'

    def __init__(self, key, iv=None, block_size=16, key_length=32, iv_length=16):
        """
        Initialize cryptographic algorithm

        Args:
            key: Secret (str or bytes); stretched to ``key_length`` bytes via PBKDF2
            iv: Fixed initialization vector (bytes); when None or empty a random
                IV is generated on every ``encrypt`` call
            block_size: Cipher block size in bytes, used for PKCS7 padding and
                ciphertext length checks
            key_length: Derived key length in bytes
            iv_length: Initialization vector length in bytes
        """
        self.block_size = block_size
        self.key_length = key_length
        self.iv_length = iv_length

        # Normalize key; must come after key_length is set because the KDF reads it
        self.key = self._normalize_key(key)
        self.iv = iv

    def _normalize_key(self, key):
        """Derive a ``key_length``-byte key from ``key`` with PBKDF2-HMAC-SHA256.

        ``str`` keys are UTF-8 encoded first. The salt and iteration count are
        fixed so the same secret always yields the same key; changing either
        would make previously encrypted data undecryptable.
        """
        if isinstance(key, str):
            key = key.encode('utf-8')

        kdf = PBKDF2HMAC(
            algorithm=hashes.SHA256(),
            length=self.key_length,
            salt=b"ragflow_crypto_salt",  # Fixed salt to ensure consistent key derivation results
            iterations=100000,
            backend=default_backend()
        )

        return kdf.derive(key)

    def encrypt(self, data):
        """
        Encrypt data (template method)

        Args:
            data: Data to encrypt (bytes)

        Returns:
            Encrypted data (bytes), format: magic_header + iv + encrypted_data

        Raises:
            ValueError: if the configured IV is not ``iv_length`` bytes long
        """
        # Use the fixed IV if one was configured, otherwise a fresh random one.
        # NOTE(review): reusing a fixed IV with CBC makes identical plaintext
        # prefixes produce identical ciphertext; prefer iv=None.
        iv = self.iv if self.iv else os.urandom(self.iv_length)
        # decrypt() slices exactly iv_length bytes back out of the header, so any
        # other size would produce output that can never be decrypted.
        if len(iv) != self.iv_length:
            raise ValueError(
                f"Invalid IV length: expected {self.iv_length} bytes, got {len(iv)}"
            )

        # Use PKCS7 padding
        padder = padding.PKCS7(self.block_size * 8).padder()
        padded_data = padder.update(data) + padder.finalize()

        # Delegate to subclass for specific encryption
        ciphertext = self._encrypt(padded_data, iv)

        # Return Magic Header + IV + encrypted data
        return self.ENCRYPTED_MAGIC + iv + ciphertext

    def decrypt(self, encrypted_data):
        """
        Decrypt data (template method)

        Data that does not start with ``ENCRYPTED_MAGIC`` is returned unchanged.
        NOTE: an unencrypted payload that happens to begin with the magic bytes
        is treated as ciphertext (typically ending in a ValueError).

        Args:
            encrypted_data: Encrypted data (bytes)

        Returns:
            Decrypted data (bytes)

        Raises:
            ValueError: if the payload is truncated, its ciphertext length is not
                a positive multiple of ``block_size``, or the padding is invalid
        """
        # Check if data is encrypted by magic header
        if not encrypted_data.startswith(self.ENCRYPTED_MAGIC):
            # Not encrypted, return as-is
            return encrypted_data

        # Remove magic header, then separate IV and ciphertext
        payload = encrypted_data[len(self.ENCRYPTED_MAGIC):]
        iv = payload[:self.iv_length]
        ciphertext = payload[self.iv_length:]

        # Validate framing up front so truncated/corrupted objects fail with a
        # clear message instead of an opaque error from the cipher backend.
        if len(iv) != self.iv_length:
            raise ValueError("Encrypted data is truncated: incomplete initialization vector")
        # PKCS7 always adds at least one block, so valid ciphertext is never empty.
        if not ciphertext or len(ciphertext) % self.block_size:
            raise ValueError(
                "Encrypted data is corrupted: ciphertext length "
                f"{len(ciphertext)} is not a positive multiple of {self.block_size}"
            )

        # Delegate to subclass for specific decryption
        padded_data = self._decrypt(ciphertext, iv)

        # Remove padding
        unpadder = padding.PKCS7(self.block_size * 8).unpadder()
        return unpadder.update(padded_data) + unpadder.finalize()

    def _encrypt(self, padded_data, iv):
        """
        Encrypt padded data with specific algorithm

        Args:
            padded_data: Padded data to encrypt
            iv: Initialization vector

        Returns:
            Encrypted data
        """
        raise NotImplementedError("_encrypt method must be implemented by subclass")

    def _decrypt(self, ciphertext, iv):
        """
        Decrypt ciphertext with specific algorithm

        Args:
            ciphertext: Ciphertext to decrypt
            iv: Initialization vector

        Returns:
            Decrypted padded data
        """
        raise NotImplementedError("_decrypt method must be implemented by subclass")
|
|
|
|
|
|
class AESCrypto(BaseCrypto):
    """AES in CBC mode; concrete subclasses only pick the derived key size."""

    def __init__(self, key, iv=None, key_length=32):
        """
        Set up an AES-CBC cipher.

        Args:
            key: Secret (str or bytes) from which the AES key is derived
            iv: Optional fixed IV; when None a random IV is used per encryption
            key_length: Derived key size in bytes (16 for AES-128, 32 for AES-256)
        """
        super().__init__(key, iv=iv, block_size=16, key_length=key_length, iv_length=16)

    def _cipher(self, iv):
        """Build an AES-CBC Cipher object for the derived key and the given IV."""
        return Cipher(algorithms.AES(self.key), modes.CBC(iv), backend=default_backend())

    def _encrypt(self, padded_data, iv):
        """Run AES-CBC encryption over already padded data."""
        enc = self._cipher(iv).encryptor()
        return enc.update(padded_data) + enc.finalize()

    def _decrypt(self, ciphertext, iv):
        """Run AES-CBC decryption; the result still carries PKCS7 padding."""
        dec = self._cipher(iv).decryptor()
        return dec.update(ciphertext) + dec.finalize()
|
|
|
|
|
|
class AES128CBC(AESCrypto):
    """AES in CBC mode with a 128-bit (16-byte) derived key."""

    def __init__(self, key, iv=None):
        """
        Create an AES-128-CBC cipher.

        Args:
            key: Secret (str or bytes) from which the 16-byte key is derived
            iv: Optional fixed IV; when None a random IV is used per encryption
        """
        super().__init__(key=key, iv=iv, key_length=16)
|
|
|
|
|
|
class AES256CBC(AESCrypto):
    """AES in CBC mode with a 256-bit (32-byte) derived key."""

    def __init__(self, key, iv=None):
        """
        Create an AES-256-CBC cipher.

        Args:
            key: Secret (str or bytes) from which the 32-byte key is derived
            iv: Optional fixed IV; when None a random IV is used per encryption
        """
        super().__init__(key=key, iv=iv, key_length=32)
|
|
|
|
|
|
class SM4CBC(BaseCrypto):
    """SM4 block cipher in CBC mode, backed by the cryptography library's SM4."""

    def __init__(self, key, iv=None):
        """
        Set up an SM4-CBC cipher (16-byte derived key, 16-byte block and IV).

        Args:
            key: Secret (str or bytes) from which the 16-byte key is derived
            iv: Optional fixed IV; when None a random IV is used per encryption
        """
        super().__init__(key, iv=iv, block_size=16, key_length=16, iv_length=16)

    def _cipher(self, iv):
        """Build an SM4-CBC Cipher object for the derived key and the given IV."""
        return Cipher(algorithms.SM4(self.key), modes.CBC(iv), backend=default_backend())

    def _encrypt(self, padded_data, iv):
        """Run SM4-CBC encryption over already padded data."""
        enc = self._cipher(iv).encryptor()
        return enc.update(padded_data) + enc.finalize()

    def _decrypt(self, ciphertext, iv):
        """Run SM4-CBC decryption; the result still carries PKCS7 padding."""
        dec = self._cipher(iv).decryptor()
        return dec.update(ciphertext) + dec.finalize()
|
|
|
|
|
|
class CryptoUtil:
    """Cryptographic utility class, using factory pattern to create cryptographic algorithm instances"""

    # Supported cryptographic algorithms mapping
    SUPPORTED_ALGORITHMS = {
        "aes-128-cbc": AES128CBC,
        "aes-256-cbc": AES256CBC,
        "sm4-cbc": SM4CBC,
    }

    # Environment variable consulted when no explicit key is passed to __init__
    KEY_ENV_VAR = "RAGFLOW_CRYPTO_KEY"

    def __init__(self, algorithm="aes-256-cbc", key=None, iv=None):
        """
        Initialize cryptographic utility

        Args:
            algorithm: Name of a key in SUPPORTED_ALGORITHMS, default is aes-256-cbc
            key: Encryption key; when falsy (None or empty) the value of the
                RAGFLOW_CRYPTO_KEY environment variable is used instead
            iv: Initialization vector, automatically generated per encryption if None

        Raises:
            ValueError: if the algorithm is unsupported, or no key is available
                from either the argument or the environment
        """
        if algorithm not in self.SUPPORTED_ALGORITHMS:
            raise ValueError(f"Unsupported algorithm: {algorithm}")

        if not key:
            # The error below tells users to set this variable, so it must
            # actually be consulted as the fallback key source.
            key = os.environ.get(self.KEY_ENV_VAR)
        if not key:
            raise ValueError("Encryption key not provided and RAGFLOW_CRYPTO_KEY environment variable not set")

        # Create cryptographic algorithm instance
        self.algorithm_name = algorithm
        self.crypto = self.SUPPORTED_ALGORITHMS[algorithm](key=key, iv=iv)

    def encrypt(self, data):
        """
        Encrypt data

        Args:
            data: Data to encrypt (bytes)

        Returns:
            Encrypted data (bytes)
        """
        return self.crypto.encrypt(data)

    def decrypt(self, encrypted_data):
        """
        Decrypt data

        Args:
            encrypted_data: Encrypted data (bytes)

        Returns:
            Decrypted data (bytes)
        """
        return self.crypto.decrypt(encrypted_data)
|
|
|
|
|
|
# Manual smoke test: round-trip a sample payload through every algorithm.
if __name__ == "__main__":
    demo_key = "test_key_123456"
    sample = b"Hello, RAGFlow! This is a test for encryption."

    def _print_roundtrip(title, ciphertext, plaintext):
        """Print the original/encrypted/decrypted triple and whether it round-tripped."""
        print(title)
        for label, value in (("Original", sample), ("Encrypted", ciphertext), ("Decrypted", plaintext)):
            print(f"{label}: {value}")
        print(f"Success: {sample == plaintext}")

    # Factory path, AES-256-CBC
    aes_util = CryptoUtil(algorithm="aes-256-cbc", key=demo_key)
    aes_ciphertext = aes_util.encrypt(sample)
    _print_roundtrip("AES Test:", aes_ciphertext, aes_util.decrypt(aes_ciphertext))
    print()

    # Factory path, SM4-CBC (may be unavailable in some cryptography builds)
    try:
        sm4_util = CryptoUtil(algorithm="sm4-cbc", key=demo_key)
        sm4_ciphertext = sm4_util.encrypt(sample)
        _print_roundtrip("SM4 Test:", sm4_ciphertext, sm4_util.decrypt(sm4_ciphertext))
    except Exception as e:
        print(f"SM4 Test Failed: {e}")
        import traceback
        traceback.print_exc()

    # Algorithm classes used directly, bypassing the factory
    print("\nDirect Algorithm Class Test:")

    for algo_cls, label in ((AES128CBC, "AES-128-CBC"), (AES256CBC, "AES-256-CBC")):
        engine = algo_cls(key=demo_key)
        roundtrip = engine.decrypt(engine.encrypt(sample))
        print(f"{label} test: {'passed' if roundtrip == sample else 'failed'}")

    try:
        sm4_engine = SM4CBC(key=demo_key)
        sm4_roundtrip = sm4_engine.decrypt(sm4_engine.encrypt(sample))
        print(f"SM4-CBC test: {'passed' if sm4_roundtrip == sample else 'failed'}")
    except Exception as e:
        print(f"SM4-CBC test failed: {e}")
|