Feat:memory sdk (#12538)

### What problem does this PR solve?

Move memory and message apis to /api, and add sdk support.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Lynn
2026-01-09 17:45:58 +08:00
committed by GitHub
parent 64b1e0b4c3
commit f9d4179bf2
22 changed files with 1475 additions and 61 deletions

View File

@ -0,0 +1,52 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pytest
import random
@pytest.fixture(scope="class")
def add_memory_func(client, request):
def cleanup():
memory_list_res = client.list_memory()
exist_memory_ids = [memory.id for memory in memory_list_res["memory_list"]]
for memory_id in exist_memory_ids:
client.delete_memory(memory_id)
request.addfinalizer(cleanup)
memory_ids = []
for i in range(3):
payload = {
"name": f"test_memory_{i}",
"memory_type": ["raw"] + random.choices(["semantic", "episodic", "procedural"], k=random.randint(0, 3)),
"embd_id": "BAAI/bge-large-zh-v1.5@SILICONFLOW",
"llm_id": "glm-4-flash@ZHIPU-AI"
}
res = client.create_memory(**payload)
memory_ids.append(res.id)
request.cls.memory_ids = memory_ids
return memory_ids
@pytest.fixture(scope="class")
def delete_test_memory(client, request):
def cleanup():
memory_list_res = client.list_memory()
exist_memory_ids = [memory.id for memory in memory_list_res["memory_list"]]
for memory_id in exist_memory_ids:
client.delete_memory(memory_id)
request.addfinalizer(cleanup)
return

View File

@ -0,0 +1,108 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import random
import re
import pytest
from configs import INVALID_API_TOKEN, HOST_ADDRESS
from ragflow_sdk import RAGFlow
from hypothesis import example, given, settings
from utils.hypothesis_utils import valid_names
class TestAuthorization:
@pytest.mark.p1
@pytest.mark.parametrize(
"invalid_auth, expected_message",
[
(None, "<Unauthorized '401: Unauthorized'>"),
(INVALID_API_TOKEN, "<Unauthorized '401: Unauthorized'>"),
],
ids=["empty_auth", "invalid_api_token"]
)
def test_auth_invalid(self, invalid_auth, expected_message):
client = RAGFlow(invalid_auth, HOST_ADDRESS)
with pytest.raises(Exception) as exception_info:
client.create_memory(**{"name": "test_memory", "memory_type": ["raw"], "embd_id": "BAAI/bge-large-zh-v1.5@SILICONFLOW", "llm_id": "glm-4-flash@ZHIPU-AI"})
assert str(exception_info.value) == expected_message, str(exception_info.value)
@pytest.mark.usefixtures("delete_test_memory")
class TestMemoryCreate:
@pytest.mark.p1
@given(name=valid_names())
@example("e" * 128)
@settings(max_examples=20)
def test_name(self, client, name):
payload = {
"name": name,
"memory_type": ["raw"] + random.choices(["semantic", "episodic", "procedural"], k=random.randint(0, 3)),
"embd_id": "BAAI/bge-large-zh-v1.5@SILICONFLOW",
"llm_id": "glm-4-flash@ZHIPU-AI"
}
memory = client.create_memory(**payload)
pattern = rf'^{name}|{name}(?:\((\d+)\))?$'
escaped_name = re.escape(memory.name)
assert re.match(pattern, escaped_name), str(memory)
@pytest.mark.p2
@pytest.mark.parametrize(
"name, expected_message",
[
("", "Memory name cannot be empty or whitespace."),
(" ", "Memory name cannot be empty or whitespace."),
("a" * 129, f"Memory name '{'a'*129}' exceeds limit of 128."),
],
ids=["empty_name", "space_name", "too_long_name"],
)
def test_name_invalid(self, client, name, expected_message):
payload = {
"name": name,
"memory_type": ["raw"] + random.choices(["semantic", "episodic", "procedural"], k=random.randint(0, 3)),
"embd_id": "BAAI/bge-large-zh-v1.5@SILICONFLOW",
"llm_id": "glm-4-flash@ZHIPU-AI"
}
with pytest.raises(Exception) as exception_info:
client.create_memory(**payload)
assert str(exception_info.value) == expected_message, str(exception_info.value)
@pytest.mark.p2
@given(name=valid_names())
def test_type_invalid(self, client, name):
payload = {
"name": name,
"memory_type": ["something"],
"embd_id": "BAAI/bge-large-zh-v1.5@SILICONFLOW",
"llm_id": "glm-4-flash@ZHIPU-AI"
}
with pytest.raises(Exception) as exception_info:
client.create_memory(**payload)
assert str(exception_info.value) == f"Memory type '{ {'something'} }' is not supported.", str(exception_info.value)
@pytest.mark.p3
def test_name_duplicated(self, client):
name = "duplicated_name_test"
payload = {
"name": name,
"memory_type": ["raw"] + random.choices(["semantic", "episodic", "procedural"], k=random.randint(0, 3)),
"embd_id": "BAAI/bge-large-zh-v1.5@SILICONFLOW",
"llm_id": "glm-4-flash@ZHIPU-AI"
}
res1 = client.create_memory(**payload)
assert res1.name == name, str(res1)
res2 = client.create_memory(**payload)
assert res2.name == f"{name}(1)", str(res2)

View File

@ -0,0 +1,116 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from concurrent.futures import ThreadPoolExecutor, as_completed
import pytest
from ragflow_sdk import RAGFlow
from configs import INVALID_API_TOKEN, HOST_ADDRESS
class TestAuthorization:
@pytest.mark.p1
@pytest.mark.parametrize(
"invalid_auth, expected_message",
[
(None, "<Unauthorized '401: Unauthorized'>"),
(INVALID_API_TOKEN, "<Unauthorized '401: Unauthorized'>"),
],
)
def test_auth_invalid(self, invalid_auth, expected_message):
client = RAGFlow(invalid_auth, HOST_ADDRESS)
with pytest.raises(Exception) as exception_info:
client.list_memory()
assert str(exception_info.value) == expected_message, str(exception_info.value)
class TestCapability:
@pytest.mark.p3
def test_capability(self, client):
count = 100
with ThreadPoolExecutor(max_workers=5) as executor:
futures = [executor.submit(client.list_memory) for _ in range(count)]
responses = list(as_completed(futures))
assert len(responses) == count, responses
assert all(future.result()["code"] == 0 for future in futures)
@pytest.mark.usefixtures("add_memory_func")
class TestMemoryList:
@pytest.mark.p1
def test_params_unset(self, client):
res = client.list_memory()
assert len(res["memory_list"]) == 3, str(res)
assert res["total_count"] == 3, str(res)
@pytest.mark.p1
def test_params_empty(self, client):
res = client.list_memory(**{})
assert len(res["memory_list"]) == 3, str(res)
assert res["total_count"] == 3, str(res)
@pytest.mark.p1
@pytest.mark.parametrize(
"params, expected_page_size",
[
({"page": 1, "page_size": 10}, 3),
({"page": 2, "page_size": 10}, 0),
({"page": 1, "page_size": 2}, 2),
({"page": 2, "page_size": 2}, 1),
({"page": 5, "page_size": 10}, 0),
],
ids=["normal_first_page", "beyond_max_page", "normal_last_partial_page" , "normal_middle_page",
"full_data_single_page"],
)
def test_page(self, client, params, expected_page_size):
# have added 3 memories in fixture
res = client.list_memory(**params)
assert len(res["memory_list"]) == expected_page_size, str(res)
assert res["total_count"] == 3, str(res)
@pytest.mark.p2
def test_filter_memory_type(self, client):
res = client.list_memory(**{"memory_type": ["semantic"]})
for memory in res["memory_list"]:
assert "semantic" in memory.memory_type, str(memory)
@pytest.mark.p2
def test_filter_multi_memory_type(self, client):
res = client.list_memory(**{"memory_type": ["episodic", "procedural"]})
for memory in res["memory_list"]:
assert "episodic" in memory.memory_type or "procedural" in memory.memory_type, str(memory)
@pytest.mark.p2
def test_filter_storage_type(self, client):
res = client.list_memory(**{"storage_type": "table"})
for memory in res["memory_list"]:
assert memory.storage_type == "table", str(memory)
@pytest.mark.p2
def test_match_keyword(self, client):
res = client.list_memory(**{"keywords": "s"})
for memory in res["memory_list"]:
assert "s" in memory.name, str(memory)
@pytest.mark.p1
def test_get_config(self, client):
memory_list = client.list_memory()
assert len(memory_list["memory_list"]) > 0, str(memory_list)
memory = memory_list["memory_list"][0]
memory_id = memory.id
memory_config = memory.get_config()
assert memory_config.id == memory_id, memory_config
for field in ["name", "avatar", "tenant_id", "owner_name", "memory_type", "storage_type",
"embd_id", "llm_id", "permissions", "description", "memory_size", "forgetting_policy",
"temperature", "system_prompt", "user_prompt"]:
assert hasattr(memory, field), memory_config

View File

@ -0,0 +1,52 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pytest
from ragflow_sdk import RAGFlow
from configs import INVALID_API_TOKEN, HOST_ADDRESS
class TestAuthorization:
@pytest.mark.p1
@pytest.mark.parametrize(
"invalid_auth, expected_message",
[
(None, "<Unauthorized '401: Unauthorized'>"),
(INVALID_API_TOKEN, "<Unauthorized '401: Unauthorized'>"),
],
)
def test_auth_invalid(self, invalid_auth, expected_message):
client = RAGFlow(invalid_auth, HOST_ADDRESS)
with pytest.raises(Exception) as exception_info:
client.delete_memory("some_memory_id")
assert str(exception_info.value) == expected_message, str(exception_info.value)
@pytest.mark.usefixtures("add_memory_func")
class TestMemoryDelete:
@pytest.mark.p1
def test_memory_id(self, client):
memory_ids = self.memory_ids
client.delete_memory(memory_ids[0])
res = client.list_memory()
assert res["total_count"] == 2, res
@pytest.mark.p2
def test_id_wrong_uuid(self, client):
with pytest.raises(Exception) as exception_info:
client.delete_memory("d94a8dc02c9711f0930f7fbc369eab6d")
assert exception_info.value, str(exception_info.value)
res = client.list_memory()
assert len(res["memory_list"]) == 2, res

View File

@ -0,0 +1,164 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import random
import pytest
from configs import INVALID_API_TOKEN, HOST_ADDRESS
from ragflow_sdk import RAGFlow, Memory
from hypothesis import HealthCheck, example, given, settings
from utils import encode_avatar
from utils.file_utils import create_image_file
from utils.hypothesis_utils import valid_names
class TestAuthorization:
@pytest.mark.p1
@pytest.mark.parametrize(
"invalid_auth, expected_message",
[
(None, "<Unauthorized '401: Unauthorized'>"),
(INVALID_API_TOKEN, "<Unauthorized '401: Unauthorized'>"),
],
ids=["empty_auth", "invalid_api_token"]
)
def test_auth_invalid(self, invalid_auth, expected_message):
with pytest.raises(Exception) as exception_info:
client = RAGFlow(invalid_auth, HOST_ADDRESS)
memory = Memory(client, {"id": "memory_id"})
memory.update({"name": "New_Name"})
assert str(exception_info.value) == expected_message, str(exception_info.value)
@pytest.mark.usefixtures("add_memory_func")
class TestMemoryUpdate:
@pytest.mark.p1
@given(name=valid_names())
@example("f" * 128)
@settings(max_examples=20, suppress_health_check=[HealthCheck.function_scoped_fixture])
def test_name(self, client, name):
memory_ids = self.memory_ids
update_dict = {"name": name}
memory = Memory(client, {"id": random.choice(memory_ids)})
res = memory.update(update_dict)
assert res.name == name, str(res)
@pytest.mark.p2
@pytest.mark.parametrize(
"name, expected_message",
[
("", "Memory name cannot be empty or whitespace."),
(" ", "Memory name cannot be empty or whitespace."),
("a" * 129, f"Memory name '{'a' * 129}' exceeds limit of 128."),
]
)
def test_name_invalid(self, client, name, expected_message):
memory_ids = self.memory_ids
update_dict = {"name": name}
memory = Memory(client, {"id": random.choice(memory_ids)})
with pytest.raises(Exception) as exception_info:
memory.update(update_dict)
assert str(exception_info.value) == expected_message, str(exception_info.value)
@pytest.mark.p2
def test_duplicate_name(self, client):
memory_ids = self.memory_ids
update_dict = {"name": "Test_Memory"}
memory_0 = Memory(client, {"id": memory_ids[0]})
res_0 = memory_0.update(update_dict)
assert res_0.name == "Test_Memory", str(res_0)
memory_1 = Memory(client, {"id": memory_ids[1]})
res_1 = memory_1.update(update_dict)
assert res_1.name == "Test_Memory(1)", str(res_1)
@pytest.mark.p1
def test_avatar(self, client, tmp_path):
memory_ids = self.memory_ids
fn = create_image_file(tmp_path / "ragflow_test.png")
update_dict = {"avatar": f"data:image/png;base64,{encode_avatar(fn)}"}
memory = Memory(client, {"id": random.choice(memory_ids)})
res = memory.update(update_dict)
assert res.avatar == f"data:image/png;base64,{encode_avatar(fn)}", str(res)
@pytest.mark.p1
def test_description(self, client):
memory_ids = self.memory_ids
description = "This is a test description."
update_dict = {"description": description}
memory = Memory(client, {"id": random.choice(memory_ids)})
res = memory.update(update_dict)
assert res.description == description, str(res)
@pytest.mark.p1
def test_llm(self, client):
memory_ids = self.memory_ids
llm_id = "glm-4@ZHIPU-AI"
update_dict = {"llm_id": llm_id}
memory = Memory(client, {"id": random.choice(memory_ids)})
res = memory.update(update_dict)
assert res.llm_id == llm_id, str(res)
@pytest.mark.p1
@pytest.mark.parametrize(
"permission",
[
"me",
"team"
],
ids=["me", "team"]
)
def test_permission(self, client, permission):
memory_ids = self.memory_ids
update_dict = {"permissions": permission}
memory = Memory(client, {"id": random.choice(memory_ids)})
res = memory.update(update_dict)
assert res.permissions == permission.lower().strip(), str(res)
@pytest.mark.p1
def test_memory_size(self, client):
memory_ids = self.memory_ids
memory_size = 1048576 # 1 MB
update_dict = {"memory_size": memory_size}
memory = Memory(client, {"id": random.choice(memory_ids)})
res = memory.update(update_dict)
assert res.memory_size == memory_size, str(res)
@pytest.mark.p1
def test_temperature(self, client):
memory_ids = self.memory_ids
temperature = 0.7
update_dict = {"temperature": temperature}
memory = Memory(client, {"id": random.choice(memory_ids)})
res = memory.update(update_dict)
assert res.temperature == temperature, str(res)
@pytest.mark.p1
def test_system_prompt(self, client):
memory_ids = self.memory_ids
system_prompt = "This is a system prompt."
update_dict = {"system_prompt": system_prompt}
memory = Memory(client, {"id": random.choice(memory_ids)})
res = memory.update(update_dict)
assert res.system_prompt == system_prompt, str(res)
@pytest.mark.p1
def test_user_prompt(self, client):
memory_ids = self.memory_ids
user_prompt = "This is a user prompt."
update_dict = {"user_prompt": user_prompt}
memory = Memory(client, {"id": random.choice(memory_ids)})
res = memory.update(update_dict)
assert res.user_prompt == user_prompt, res