mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-01-30 15:16:45 +08:00
Feat:memory sdk (#12538)
### What problem does this PR solve? Move memory and message apis to /api, and add sdk support. ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -0,0 +1,52 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import pytest
|
||||
import random
|
||||
|
||||
@pytest.fixture(scope="class")
|
||||
def add_memory_func(client, request):
|
||||
def cleanup():
|
||||
memory_list_res = client.list_memory()
|
||||
exist_memory_ids = [memory.id for memory in memory_list_res["memory_list"]]
|
||||
for memory_id in exist_memory_ids:
|
||||
client.delete_memory(memory_id)
|
||||
|
||||
request.addfinalizer(cleanup)
|
||||
|
||||
memory_ids = []
|
||||
for i in range(3):
|
||||
payload = {
|
||||
"name": f"test_memory_{i}",
|
||||
"memory_type": ["raw"] + random.choices(["semantic", "episodic", "procedural"], k=random.randint(0, 3)),
|
||||
"embd_id": "BAAI/bge-large-zh-v1.5@SILICONFLOW",
|
||||
"llm_id": "glm-4-flash@ZHIPU-AI"
|
||||
}
|
||||
res = client.create_memory(**payload)
|
||||
memory_ids.append(res.id)
|
||||
request.cls.memory_ids = memory_ids
|
||||
return memory_ids
|
||||
|
||||
|
||||
@pytest.fixture(scope="class")
|
||||
def delete_test_memory(client, request):
|
||||
def cleanup():
|
||||
memory_list_res = client.list_memory()
|
||||
exist_memory_ids = [memory.id for memory in memory_list_res["memory_list"]]
|
||||
for memory_id in exist_memory_ids:
|
||||
client.delete_memory(memory_id)
|
||||
|
||||
request.addfinalizer(cleanup)
|
||||
return
|
||||
@ -0,0 +1,108 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import random
|
||||
import re
|
||||
|
||||
import pytest
|
||||
from configs import INVALID_API_TOKEN, HOST_ADDRESS
|
||||
from ragflow_sdk import RAGFlow
|
||||
from hypothesis import example, given, settings
|
||||
from utils.hypothesis_utils import valid_names
|
||||
|
||||
|
||||
class TestAuthorization:
|
||||
@pytest.mark.p1
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_auth, expected_message",
|
||||
[
|
||||
(None, "<Unauthorized '401: Unauthorized'>"),
|
||||
(INVALID_API_TOKEN, "<Unauthorized '401: Unauthorized'>"),
|
||||
],
|
||||
ids=["empty_auth", "invalid_api_token"]
|
||||
)
|
||||
def test_auth_invalid(self, invalid_auth, expected_message):
|
||||
client = RAGFlow(invalid_auth, HOST_ADDRESS)
|
||||
with pytest.raises(Exception) as exception_info:
|
||||
client.create_memory(**{"name": "test_memory", "memory_type": ["raw"], "embd_id": "BAAI/bge-large-zh-v1.5@SILICONFLOW", "llm_id": "glm-4-flash@ZHIPU-AI"})
|
||||
assert str(exception_info.value) == expected_message, str(exception_info.value)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("delete_test_memory")
|
||||
class TestMemoryCreate:
|
||||
@pytest.mark.p1
|
||||
@given(name=valid_names())
|
||||
@example("e" * 128)
|
||||
@settings(max_examples=20)
|
||||
def test_name(self, client, name):
|
||||
payload = {
|
||||
"name": name,
|
||||
"memory_type": ["raw"] + random.choices(["semantic", "episodic", "procedural"], k=random.randint(0, 3)),
|
||||
"embd_id": "BAAI/bge-large-zh-v1.5@SILICONFLOW",
|
||||
"llm_id": "glm-4-flash@ZHIPU-AI"
|
||||
}
|
||||
memory = client.create_memory(**payload)
|
||||
pattern = rf'^{name}|{name}(?:\((\d+)\))?$'
|
||||
escaped_name = re.escape(memory.name)
|
||||
assert re.match(pattern, escaped_name), str(memory)
|
||||
|
||||
@pytest.mark.p2
|
||||
@pytest.mark.parametrize(
|
||||
"name, expected_message",
|
||||
[
|
||||
("", "Memory name cannot be empty or whitespace."),
|
||||
(" ", "Memory name cannot be empty or whitespace."),
|
||||
("a" * 129, f"Memory name '{'a'*129}' exceeds limit of 128."),
|
||||
],
|
||||
ids=["empty_name", "space_name", "too_long_name"],
|
||||
)
|
||||
def test_name_invalid(self, client, name, expected_message):
|
||||
payload = {
|
||||
"name": name,
|
||||
"memory_type": ["raw"] + random.choices(["semantic", "episodic", "procedural"], k=random.randint(0, 3)),
|
||||
"embd_id": "BAAI/bge-large-zh-v1.5@SILICONFLOW",
|
||||
"llm_id": "glm-4-flash@ZHIPU-AI"
|
||||
}
|
||||
with pytest.raises(Exception) as exception_info:
|
||||
client.create_memory(**payload)
|
||||
assert str(exception_info.value) == expected_message, str(exception_info.value)
|
||||
|
||||
@pytest.mark.p2
|
||||
@given(name=valid_names())
|
||||
def test_type_invalid(self, client, name):
|
||||
payload = {
|
||||
"name": name,
|
||||
"memory_type": ["something"],
|
||||
"embd_id": "BAAI/bge-large-zh-v1.5@SILICONFLOW",
|
||||
"llm_id": "glm-4-flash@ZHIPU-AI"
|
||||
}
|
||||
with pytest.raises(Exception) as exception_info:
|
||||
client.create_memory(**payload)
|
||||
assert str(exception_info.value) == f"Memory type '{ {'something'} }' is not supported.", str(exception_info.value)
|
||||
|
||||
@pytest.mark.p3
|
||||
def test_name_duplicated(self, client):
|
||||
name = "duplicated_name_test"
|
||||
payload = {
|
||||
"name": name,
|
||||
"memory_type": ["raw"] + random.choices(["semantic", "episodic", "procedural"], k=random.randint(0, 3)),
|
||||
"embd_id": "BAAI/bge-large-zh-v1.5@SILICONFLOW",
|
||||
"llm_id": "glm-4-flash@ZHIPU-AI"
|
||||
}
|
||||
res1 = client.create_memory(**payload)
|
||||
assert res1.name == name, str(res1)
|
||||
|
||||
res2 = client.create_memory(**payload)
|
||||
assert res2.name == f"{name}(1)", str(res2)
|
||||
@ -0,0 +1,116 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
import pytest
|
||||
from ragflow_sdk import RAGFlow
|
||||
from configs import INVALID_API_TOKEN, HOST_ADDRESS
|
||||
|
||||
class TestAuthorization:
|
||||
@pytest.mark.p1
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_auth, expected_message",
|
||||
[
|
||||
(None, "<Unauthorized '401: Unauthorized'>"),
|
||||
(INVALID_API_TOKEN, "<Unauthorized '401: Unauthorized'>"),
|
||||
],
|
||||
)
|
||||
def test_auth_invalid(self, invalid_auth, expected_message):
|
||||
client = RAGFlow(invalid_auth, HOST_ADDRESS)
|
||||
with pytest.raises(Exception) as exception_info:
|
||||
client.list_memory()
|
||||
assert str(exception_info.value) == expected_message, str(exception_info.value)
|
||||
|
||||
|
||||
class TestCapability:
|
||||
@pytest.mark.p3
|
||||
def test_capability(self, client):
|
||||
count = 100
|
||||
with ThreadPoolExecutor(max_workers=5) as executor:
|
||||
futures = [executor.submit(client.list_memory) for _ in range(count)]
|
||||
responses = list(as_completed(futures))
|
||||
assert len(responses) == count, responses
|
||||
assert all(future.result()["code"] == 0 for future in futures)
|
||||
|
||||
@pytest.mark.usefixtures("add_memory_func")
|
||||
class TestMemoryList:
|
||||
@pytest.mark.p1
|
||||
def test_params_unset(self, client):
|
||||
res = client.list_memory()
|
||||
assert len(res["memory_list"]) == 3, str(res)
|
||||
assert res["total_count"] == 3, str(res)
|
||||
|
||||
@pytest.mark.p1
|
||||
def test_params_empty(self, client):
|
||||
res = client.list_memory(**{})
|
||||
assert len(res["memory_list"]) == 3, str(res)
|
||||
assert res["total_count"] == 3, str(res)
|
||||
|
||||
@pytest.mark.p1
|
||||
@pytest.mark.parametrize(
|
||||
"params, expected_page_size",
|
||||
[
|
||||
({"page": 1, "page_size": 10}, 3),
|
||||
({"page": 2, "page_size": 10}, 0),
|
||||
({"page": 1, "page_size": 2}, 2),
|
||||
({"page": 2, "page_size": 2}, 1),
|
||||
({"page": 5, "page_size": 10}, 0),
|
||||
],
|
||||
ids=["normal_first_page", "beyond_max_page", "normal_last_partial_page" , "normal_middle_page",
|
||||
"full_data_single_page"],
|
||||
)
|
||||
def test_page(self, client, params, expected_page_size):
|
||||
# have added 3 memories in fixture
|
||||
res = client.list_memory(**params)
|
||||
assert len(res["memory_list"]) == expected_page_size, str(res)
|
||||
assert res["total_count"] == 3, str(res)
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_filter_memory_type(self, client):
|
||||
res = client.list_memory(**{"memory_type": ["semantic"]})
|
||||
for memory in res["memory_list"]:
|
||||
assert "semantic" in memory.memory_type, str(memory)
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_filter_multi_memory_type(self, client):
|
||||
res = client.list_memory(**{"memory_type": ["episodic", "procedural"]})
|
||||
for memory in res["memory_list"]:
|
||||
assert "episodic" in memory.memory_type or "procedural" in memory.memory_type, str(memory)
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_filter_storage_type(self, client):
|
||||
res = client.list_memory(**{"storage_type": "table"})
|
||||
for memory in res["memory_list"]:
|
||||
assert memory.storage_type == "table", str(memory)
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_match_keyword(self, client):
|
||||
res = client.list_memory(**{"keywords": "s"})
|
||||
for memory in res["memory_list"]:
|
||||
assert "s" in memory.name, str(memory)
|
||||
|
||||
@pytest.mark.p1
|
||||
def test_get_config(self, client):
|
||||
memory_list = client.list_memory()
|
||||
assert len(memory_list["memory_list"]) > 0, str(memory_list)
|
||||
memory = memory_list["memory_list"][0]
|
||||
memory_id = memory.id
|
||||
memory_config = memory.get_config()
|
||||
assert memory_config.id == memory_id, memory_config
|
||||
for field in ["name", "avatar", "tenant_id", "owner_name", "memory_type", "storage_type",
|
||||
"embd_id", "llm_id", "permissions", "description", "memory_size", "forgetting_policy",
|
||||
"temperature", "system_prompt", "user_prompt"]:
|
||||
assert hasattr(memory, field), memory_config
|
||||
@ -0,0 +1,52 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import pytest
|
||||
from ragflow_sdk import RAGFlow
|
||||
from configs import INVALID_API_TOKEN, HOST_ADDRESS
|
||||
|
||||
class TestAuthorization:
|
||||
@pytest.mark.p1
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_auth, expected_message",
|
||||
[
|
||||
(None, "<Unauthorized '401: Unauthorized'>"),
|
||||
(INVALID_API_TOKEN, "<Unauthorized '401: Unauthorized'>"),
|
||||
],
|
||||
)
|
||||
def test_auth_invalid(self, invalid_auth, expected_message):
|
||||
client = RAGFlow(invalid_auth, HOST_ADDRESS)
|
||||
with pytest.raises(Exception) as exception_info:
|
||||
client.delete_memory("some_memory_id")
|
||||
assert str(exception_info.value) == expected_message, str(exception_info.value)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("add_memory_func")
|
||||
class TestMemoryDelete:
|
||||
@pytest.mark.p1
|
||||
def test_memory_id(self, client):
|
||||
memory_ids = self.memory_ids
|
||||
client.delete_memory(memory_ids[0])
|
||||
res = client.list_memory()
|
||||
assert res["total_count"] == 2, res
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_id_wrong_uuid(self, client):
|
||||
with pytest.raises(Exception) as exception_info:
|
||||
client.delete_memory("d94a8dc02c9711f0930f7fbc369eab6d")
|
||||
assert exception_info.value, str(exception_info.value)
|
||||
|
||||
res = client.list_memory()
|
||||
assert len(res["memory_list"]) == 2, res
|
||||
@ -0,0 +1,164 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import random
|
||||
import pytest
|
||||
from configs import INVALID_API_TOKEN, HOST_ADDRESS
|
||||
from ragflow_sdk import RAGFlow, Memory
|
||||
from hypothesis import HealthCheck, example, given, settings
|
||||
from utils import encode_avatar
|
||||
from utils.file_utils import create_image_file
|
||||
from utils.hypothesis_utils import valid_names
|
||||
|
||||
|
||||
class TestAuthorization:
|
||||
@pytest.mark.p1
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_auth, expected_message",
|
||||
[
|
||||
(None, "<Unauthorized '401: Unauthorized'>"),
|
||||
(INVALID_API_TOKEN, "<Unauthorized '401: Unauthorized'>"),
|
||||
],
|
||||
ids=["empty_auth", "invalid_api_token"]
|
||||
)
|
||||
def test_auth_invalid(self, invalid_auth, expected_message):
|
||||
|
||||
with pytest.raises(Exception) as exception_info:
|
||||
client = RAGFlow(invalid_auth, HOST_ADDRESS)
|
||||
memory = Memory(client, {"id": "memory_id"})
|
||||
memory.update({"name": "New_Name"})
|
||||
assert str(exception_info.value) == expected_message, str(exception_info.value)
|
||||
|
||||
@pytest.mark.usefixtures("add_memory_func")
|
||||
class TestMemoryUpdate:
|
||||
|
||||
@pytest.mark.p1
|
||||
@given(name=valid_names())
|
||||
@example("f" * 128)
|
||||
@settings(max_examples=20, suppress_health_check=[HealthCheck.function_scoped_fixture])
|
||||
def test_name(self, client, name):
|
||||
memory_ids = self.memory_ids
|
||||
update_dict = {"name": name}
|
||||
memory = Memory(client, {"id": random.choice(memory_ids)})
|
||||
res = memory.update(update_dict)
|
||||
assert res.name == name, str(res)
|
||||
|
||||
@pytest.mark.p2
|
||||
@pytest.mark.parametrize(
|
||||
"name, expected_message",
|
||||
[
|
||||
("", "Memory name cannot be empty or whitespace."),
|
||||
(" ", "Memory name cannot be empty or whitespace."),
|
||||
("a" * 129, f"Memory name '{'a' * 129}' exceeds limit of 128."),
|
||||
]
|
||||
)
|
||||
def test_name_invalid(self, client, name, expected_message):
|
||||
memory_ids = self.memory_ids
|
||||
update_dict = {"name": name}
|
||||
memory = Memory(client, {"id": random.choice(memory_ids)})
|
||||
with pytest.raises(Exception) as exception_info:
|
||||
memory.update(update_dict)
|
||||
assert str(exception_info.value) == expected_message, str(exception_info.value)
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_duplicate_name(self, client):
|
||||
memory_ids = self.memory_ids
|
||||
update_dict = {"name": "Test_Memory"}
|
||||
memory_0 = Memory(client, {"id": memory_ids[0]})
|
||||
res_0 = memory_0.update(update_dict)
|
||||
assert res_0.name == "Test_Memory", str(res_0)
|
||||
|
||||
memory_1 = Memory(client, {"id": memory_ids[1]})
|
||||
res_1 = memory_1.update(update_dict)
|
||||
assert res_1.name == "Test_Memory(1)", str(res_1)
|
||||
|
||||
@pytest.mark.p1
|
||||
def test_avatar(self, client, tmp_path):
|
||||
memory_ids = self.memory_ids
|
||||
fn = create_image_file(tmp_path / "ragflow_test.png")
|
||||
update_dict = {"avatar": f"data:image/png;base64,{encode_avatar(fn)}"}
|
||||
memory = Memory(client, {"id": random.choice(memory_ids)})
|
||||
res = memory.update(update_dict)
|
||||
assert res.avatar == f"data:image/png;base64,{encode_avatar(fn)}", str(res)
|
||||
|
||||
@pytest.mark.p1
|
||||
def test_description(self, client):
|
||||
memory_ids = self.memory_ids
|
||||
description = "This is a test description."
|
||||
update_dict = {"description": description}
|
||||
memory = Memory(client, {"id": random.choice(memory_ids)})
|
||||
res = memory.update(update_dict)
|
||||
assert res.description == description, str(res)
|
||||
|
||||
@pytest.mark.p1
|
||||
def test_llm(self, client):
|
||||
memory_ids = self.memory_ids
|
||||
llm_id = "glm-4@ZHIPU-AI"
|
||||
update_dict = {"llm_id": llm_id}
|
||||
memory = Memory(client, {"id": random.choice(memory_ids)})
|
||||
res = memory.update(update_dict)
|
||||
assert res.llm_id == llm_id, str(res)
|
||||
|
||||
@pytest.mark.p1
|
||||
@pytest.mark.parametrize(
|
||||
"permission",
|
||||
[
|
||||
"me",
|
||||
"team"
|
||||
],
|
||||
ids=["me", "team"]
|
||||
)
|
||||
def test_permission(self, client, permission):
|
||||
memory_ids = self.memory_ids
|
||||
update_dict = {"permissions": permission}
|
||||
memory = Memory(client, {"id": random.choice(memory_ids)})
|
||||
res = memory.update(update_dict)
|
||||
assert res.permissions == permission.lower().strip(), str(res)
|
||||
|
||||
@pytest.mark.p1
|
||||
def test_memory_size(self, client):
|
||||
memory_ids = self.memory_ids
|
||||
memory_size = 1048576 # 1 MB
|
||||
update_dict = {"memory_size": memory_size}
|
||||
memory = Memory(client, {"id": random.choice(memory_ids)})
|
||||
res = memory.update(update_dict)
|
||||
assert res.memory_size == memory_size, str(res)
|
||||
|
||||
@pytest.mark.p1
|
||||
def test_temperature(self, client):
|
||||
memory_ids = self.memory_ids
|
||||
temperature = 0.7
|
||||
update_dict = {"temperature": temperature}
|
||||
memory = Memory(client, {"id": random.choice(memory_ids)})
|
||||
res = memory.update(update_dict)
|
||||
assert res.temperature == temperature, str(res)
|
||||
|
||||
@pytest.mark.p1
|
||||
def test_system_prompt(self, client):
|
||||
memory_ids = self.memory_ids
|
||||
system_prompt = "This is a system prompt."
|
||||
update_dict = {"system_prompt": system_prompt}
|
||||
memory = Memory(client, {"id": random.choice(memory_ids)})
|
||||
res = memory.update(update_dict)
|
||||
assert res.system_prompt == system_prompt, str(res)
|
||||
|
||||
@pytest.mark.p1
|
||||
def test_user_prompt(self, client):
|
||||
memory_ids = self.memory_ids
|
||||
user_prompt = "This is a user prompt."
|
||||
update_dict = {"user_prompt": user_prompt}
|
||||
memory = Memory(client, {"id": random.choice(memory_ids)})
|
||||
res = memory.update(update_dict)
|
||||
assert res.user_prompt == user_prompt, res
|
||||
Reference in New Issue
Block a user