diff --git a/api/db/db_models.py b/api/db/db_models.py index ecd65ad1d..d0f08a7a0 100644 --- a/api/db/db_models.py +++ b/api/db/db_models.py @@ -36,18 +36,7 @@ from api.utils.json_encode import json_dumps, json_loads from api.utils.configs import deserialize_b64, serialize_b64 from common.time_utils import current_timestamp, timestamp_to_date, date_string_to_timestamp - - -def singleton(cls, *args, **kw): - instances = {} - - def _singleton(): - key = str(cls) + str(os.getpid()) - if key not in instances: - instances[key] = cls(*args, **kw) - return instances[key] - - return _singleton +from common.decorator import singleton CONTINUOUS_FIELD_TYPE = {IntegerField, FloatField, DateTimeField} diff --git a/common/decorator.py b/common/decorator.py new file mode 100644 index 000000000..f45a41a9d --- /dev/null +++ b/common/decorator.py @@ -0,0 +1,27 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os + +def singleton(cls, *args, **kw): + instances = {} + + def _singleton(): + key = str(cls) + str(os.getpid()) + if key not in instances: + instances[key] = cls(*args, **kw) + return instances[key] + + return _singleton \ No newline at end of file diff --git a/rag/utils/__init__.py b/rag/utils/__init__.py index 19d952d58..f3a696318 100644 --- a/rag/utils/__init__.py +++ b/rag/utils/__init__.py @@ -15,23 +15,10 @@ # import os - import tiktoken from api.utils.file_utils import get_project_base_directory - -def singleton(cls, *args, **kw): - instances = {} - - def _singleton(): - key = str(cls) + str(os.getpid()) - if key not in instances: - instances[key] = cls(*args, **kw) - return instances[key] - - return _singleton - tiktoken_cache_dir = get_project_base_directory() os.environ["TIKTOKEN_CACHE_DIR"] = tiktoken_cache_dir # encoder = tiktoken.encoding_for_model("gpt-3.5-turbo") diff --git a/rag/utils/azure_sas_conn.py b/rag/utils/azure_sas_conn.py index d1737546b..771d5afdf 100644 --- a/rag/utils/azure_sas_conn.py +++ b/rag/utils/azure_sas_conn.py @@ -19,7 +19,7 @@ import os import time from io import BytesIO from rag import settings -from rag.utils import singleton +from common.decorator import singleton from azure.storage.blob import ContainerClient diff --git a/rag/utils/azure_spn_conn.py b/rag/utils/azure_spn_conn.py index 4b6b861f4..e5a194aa8 100644 --- a/rag/utils/azure_spn_conn.py +++ b/rag/utils/azure_spn_conn.py @@ -18,7 +18,7 @@ import logging import os import time from rag import settings -from rag.utils import singleton +from common.decorator import singleton from azure.identity import ClientSecretCredential, AzureAuthorityHosts from azure.storage.filedatalake import FileSystemClient diff --git a/rag/utils/es_conn.py b/rag/utils/es_conn.py index 1ebdf4fef..e662428ba 100644 --- a/rag/utils/es_conn.py +++ b/rag/utils/es_conn.py @@ -26,7 +26,7 @@ from elasticsearch_dsl import UpdateByQuery, Q, Search, Index from elastic_transport import ConnectionTimeout from rag import settings from rag.settings import TAG_FLD, PAGERANK_FLD -from rag.utils import singleton +from common.decorator import singleton from api.utils.file_utils import get_project_base_directory from api.utils.common import convert_bytes from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, \ diff --git a/rag/utils/infinity_conn.py b/rag/utils/infinity_conn.py index 0d85d25e6..1d6dbd091 100644 --- a/rag/utils/infinity_conn.py +++ b/rag/utils/infinity_conn.py @@ -27,7 +27,7 @@ from infinity.connection_pool import ConnectionPool from infinity.errors import ErrorCode from rag import settings from rag.settings import PAGERANK_FLD, TAG_FLD -from rag.utils import singleton +from common.decorator import singleton import pandas as pd from api.utils.file_utils import get_project_base_directory from rag.nlp import is_english diff --git a/rag/utils/minio_conn.py b/rag/utils/minio_conn.py index ff15c2bbf..1777de0e5 100644 --- a/rag/utils/minio_conn.py +++ b/rag/utils/minio_conn.py @@ -21,7 +21,7 @@ from minio.commonconfig import CopySource from minio.error import S3Error from io import BytesIO from rag import settings -from rag.utils import singleton +from common.decorator import singleton @singleton diff --git a/rag/utils/opendal_conn.py b/rag/utils/opendal_conn.py index 1293516a9..41abdf343 100644 --- a/rag/utils/opendal_conn.py +++ b/rag/utils/opendal_conn.py @@ -4,7 +4,7 @@ import pymysql from urllib.parse import quote_plus from api.utils.configs import get_base_config -from rag.utils import singleton +from common.decorator import singleton CREATE_TABLE_SQL = """ diff --git a/rag/utils/opensearch_conn.py b/rag/utils/opensearch_conn.py index fc7d82cf9..3c2cf376b 100644 --- a/rag/utils/opensearch_conn.py +++ b/rag/utils/opensearch_conn.py @@ -26,7 +26,7 @@ from opensearchpy import UpdateByQuery, Q, Search, Index from opensearchpy import ConnectionTimeout from rag import settings from rag.settings import TAG_FLD, PAGERANK_FLD -from rag.utils import singleton +from common.decorator import singleton from api.utils.file_utils import get_project_base_directory from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, \ FusionExpr diff --git a/rag/utils/oss_conn.py b/rag/utils/oss_conn.py index 3e75d70fc..a34bd2323 100644 --- a/rag/utils/oss_conn.py +++ b/rag/utils/oss_conn.py @@ -19,7 +19,7 @@ from botocore.exceptions import ClientError from botocore.config import Config import time from io import BytesIO -from rag.utils import singleton +from common.decorator import singleton from rag import settings diff --git a/rag/utils/redis_conn.py b/rag/utils/redis_conn.py index 63010576e..2b295eacf 100644 --- a/rag/utils/redis_conn.py +++ b/rag/utils/redis_conn.py @@ -20,7 +20,7 @@ import uuid import valkey as redis from rag import settings -from rag.utils import singleton +from common.decorator import singleton from valkey.lock import Lock import trio diff --git a/rag/utils/s3_conn.py b/rag/utils/s3_conn.py index 2a1e08185..190f18b51 100644 --- a/rag/utils/s3_conn.py +++ b/rag/utils/s3_conn.py @@ -20,7 +20,7 @@ from botocore.exceptions import ClientError from botocore.config import Config import time from io import BytesIO -from rag.utils import singleton +from common.decorator import singleton from rag import settings @singleton diff --git a/test/unit_test/common/test_decorator.py b/test/unit_test/common/test_decorator.py new file mode 100644 index 000000000..a4bc4a24e --- /dev/null +++ b/test/unit_test/common/test_decorator.py @@ -0,0 +1,79 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from common.decorator import singleton + + +# Test class for demonstration +@singleton +class TestClass: + def __init__(self): + self.counter = 0 + + def increment(self): + self.counter += 1 + return self.counter + + +# Test cases +class TestSingleton: + + def test_state_persistence(self): + """Test that instance state persists across multiple calls""" + instance1 = TestClass() + instance1.increment() + instance1.increment() + + instance2 = TestClass() + assert instance2.counter == 2 # State should persist + + def test_multiple_calls_consistency(self): + """Test consistency across multiple calls""" + instances = [TestClass() for _ in range(5)] + + # All references should point to the same object + first_instance = instances[0] + for instance in instances: + assert instance is first_instance + + def test_instance_methods_work(self): + """Test that instance methods work correctly""" + instance = TestClass() + + # Test method calls + result1 = instance.increment() + result2 = instance.increment() + + assert result1 == 3 + assert result2 == 4 + assert instance.counter == 4 + + +# Test decorator itself +def test_singleton_decorator_returns_callable(): + """Test that the decorator returns a callable""" + + class PlainClass: + pass + + decorated_class = singleton(PlainClass) + + # Should return a function + assert callable(decorated_class) + + # Calling should return an instance of PlainClass + instance = decorated_class() + assert isinstance(instance, PlainClass)