Feat: message manage (#12196)

### What problem does this PR solve?

Manage message and use in agent.

Issue #4213 

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Lynn
2025-12-25 21:18:13 +08:00
committed by GitHub
parent fd53b83190
commit 6e9691a419
54 changed files with 3715 additions and 1039 deletions

View File

@ -278,7 +278,7 @@ class Graph:
class Canvas(Graph):
def __init__(self, dsl: str, tenant_id=None, task_id=None):
def __init__(self, dsl: str, tenant_id=None, task_id=None, canvas_id=None):
self.globals = {
"sys.query": "",
"sys.user_id": tenant_id,
@ -287,6 +287,7 @@ class Canvas(Graph):
}
self.variables = {}
super().__init__(dsl, tenant_id, task_id)
self._id = canvas_id
def load(self):
super().load()
@ -721,6 +722,9 @@ class Canvas(Graph):
def get_mode(self):
return self.components["begin"]["obj"]._param.mode
def get_sys_query(self):
return self.globals.get("sys.query", "")
def set_global_param(self, **kwargs):
self.globals.update(kwargs)

View File

@ -33,6 +33,8 @@ from common.connection_utils import timeout
from common.misc_utils import get_uuid
from common import settings
from api.db.joint_services.memory_message_service import save_to_memory
class MessageParam(ComponentParamBase):
"""
@ -166,6 +168,7 @@ class Message(ComponentBase):
self.set_output("content", all_content)
self._convert_content(all_content)
await self._save_to_memory(all_content)
def _is_jinjia2(self, content:str) -> bool:
patt = [
@ -198,6 +201,7 @@ class Message(ComponentBase):
self.set_output("content", content)
self._convert_content(content)
self._save_to_memory(content)
def thoughts(self) -> str:
return ""
@ -421,3 +425,29 @@ class Message(ComponentBase):
except Exception as e:
logging.error(f"Error converting content to {self._param.output_format}: {e}")
async def _save_to_memory(self, content):
if not self._param.memory_ids:
return True, "No memory selected."
message_dict = {
"user_id": self._canvas._tenant_id,
"agent_id": self._canvas._id,
"session_id": self._canvas.task_id,
"user_input": self._canvas.get_sys_query(),
"agent_response": content
}
res = []
for memory_id in self._param.memory_ids:
success, msg = await save_to_memory(memory_id, message_dict)
res.append({
"memory_id": memory_id,
"success": success,
"msg": msg
})
if all([r["success"] for r in res]):
return True, "Successfully added to memories."
error_text = "Some messages failed to add. " + " ".join([f"Add to memory {r['memory_id']} failed, detail: {r['msg']}" for r in res if not r["success"]])
logging.error(error_text)
return False, error_text

View File

@ -25,10 +25,12 @@ from api.db.services.document_service import DocumentService
from common.metadata_utils import apply_meta_data_filter
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api.db.services.memory_service import MemoryService
from api.db.joint_services import memory_message_service
from common import settings
from common.connection_utils import timeout
from rag.app.tag import label_question
from rag.prompts.generator import cross_languages, kb_prompt
from rag.prompts.generator import cross_languages, kb_prompt, memory_prompt
class RetrievalParam(ToolParamBase):
@ -57,6 +59,7 @@ class RetrievalParam(ToolParamBase):
self.top_n = 8
self.top_k = 1024
self.kb_ids = []
self.memory_ids = []
self.kb_vars = []
self.rerank_id = ""
self.empty_response = ""
@ -81,15 +84,7 @@ class RetrievalParam(ToolParamBase):
class Retrieval(ToolBase, ABC):
component_name = "Retrieval"
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
async def _invoke_async(self, **kwargs):
if self.check_if_canceled("Retrieval processing"):
return
if not kwargs.get("query"):
self.set_output("formalized_content", self._param.empty_response)
return
async def _retrieve_kb(self, query_text: str):
kb_ids: list[str] = []
for id in self._param.kb_ids:
if id.find("@") < 0:
@ -124,12 +119,12 @@ class Retrieval(ToolBase, ABC):
if self._param.rerank_id:
rerank_mdl = LLMBundle(kbs[0].tenant_id, LLMType.RERANK, self._param.rerank_id)
vars = self.get_input_elements_from_text(kwargs["query"])
vars = {k:o["value"] for k,o in vars.items()}
query = self.string_format(kwargs["query"], vars)
vars = self.get_input_elements_from_text(query_text)
vars = {k: o["value"] for k, o in vars.items()}
query = self.string_format(query_text, vars)
doc_ids=[]
if self._param.meta_data_filter!={}:
doc_ids = []
if self._param.meta_data_filter != {}:
metas = DocumentService.get_meta_by_kbs(kb_ids)
def _resolve_manual_filter(flt: dict) -> dict:
@ -198,18 +193,20 @@ class Retrieval(ToolBase, ABC):
if self._param.toc_enhance:
chat_mdl = LLMBundle(self._canvas._tenant_id, LLMType.CHAT)
cks = settings.retriever.retrieval_by_toc(query, kbinfos["chunks"], [kb.tenant_id for kb in kbs], chat_mdl, self._param.top_n)
cks = settings.retriever.retrieval_by_toc(query, kbinfos["chunks"], [kb.tenant_id for kb in kbs],
chat_mdl, self._param.top_n)
if self.check_if_canceled("Retrieval processing"):
return
if cks:
kbinfos["chunks"] = cks
kbinfos["chunks"] = settings.retriever.retrieval_by_children(kbinfos["chunks"], [kb.tenant_id for kb in kbs])
kbinfos["chunks"] = settings.retriever.retrieval_by_children(kbinfos["chunks"],
[kb.tenant_id for kb in kbs])
if self._param.use_kg:
ck = settings.kg_retriever.retrieval(query,
[kb.tenant_id for kb in kbs],
kb_ids,
embd_mdl,
LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT))
[kb.tenant_id for kb in kbs],
kb_ids,
embd_mdl,
LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT))
if self.check_if_canceled("Retrieval processing"):
return
if ck["content_with_weight"]:
@ -218,7 +215,8 @@ class Retrieval(ToolBase, ABC):
kbinfos = {"chunks": [], "doc_aggs": []}
if self._param.use_kg and kbs:
ck = settings.kg_retriever.retrieval(query, [kb.tenant_id for kb in kbs], filtered_kb_ids, embd_mdl, LLMBundle(kbs[0].tenant_id, LLMType.CHAT))
ck = settings.kg_retriever.retrieval(query, [kb.tenant_id for kb in kbs], filtered_kb_ids, embd_mdl,
LLMBundle(kbs[0].tenant_id, LLMType.CHAT))
if self.check_if_canceled("Retrieval processing"):
return
if ck["content_with_weight"]:
@ -248,6 +246,50 @@ class Retrieval(ToolBase, ABC):
return form_cnt
async def _retrieve_memory(self, query_text: str):
memory_ids: list[str] = [memory_id for memory_id in self._param.memory_ids]
memory_list = MemoryService.get_by_ids(memory_ids)
if not memory_list:
raise Exception("No memory is selected.")
embd_names = list({memory.embd_id for memory in memory_list})
assert len(embd_names) == 1, "Memory use different embedding models."
vars = self.get_input_elements_from_text(query_text)
vars = {k: o["value"] for k, o in vars.items()}
query = self.string_format(query_text, vars)
# query message
message_list = memory_message_service.query_message({"memory_id": memory_ids}, {
"query": query,
"similarity_threshold": self._param.similarity_threshold,
"keywords_similarity_weight": self._param.keywords_similarity_weight,
"top_n": self._param.top_n
})
if not message_list:
self.set_output("formalized_content", self._param.empty_response)
return ""
formated_content = "\n".join(memory_prompt(message_list, 200000))
# set formalized_content output
self.set_output("formalized_content", formated_content)
return formated_content
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
async def _invoke_async(self, **kwargs):
if self.check_if_canceled("Retrieval processing"):
return
if not kwargs.get("query"):
self.set_output("formalized_content", self._param.empty_response)
return
if self._param.kb_ids:
return await self._retrieve_kb(kwargs["query"])
elif hasattr(self._param, "memory_ids") and self._param.memory_ids:
return await self._retrieve_memory(kwargs["query"])
else:
self.set_output("formalized_content", self._param.empty_response)
return
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
def _invoke(self, **kwargs):
return asyncio.run(self._invoke_async(**kwargs))