mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-26 17:16:52 +08:00
Feat: message manage (#12196)
### What problem does this PR solve? Manage message and use in agent. Issue #4213 ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -278,7 +278,7 @@ class Graph:
|
||||
|
||||
class Canvas(Graph):
|
||||
|
||||
def __init__(self, dsl: str, tenant_id=None, task_id=None):
|
||||
def __init__(self, dsl: str, tenant_id=None, task_id=None, canvas_id=None):
|
||||
self.globals = {
|
||||
"sys.query": "",
|
||||
"sys.user_id": tenant_id,
|
||||
@ -287,6 +287,7 @@ class Canvas(Graph):
|
||||
}
|
||||
self.variables = {}
|
||||
super().__init__(dsl, tenant_id, task_id)
|
||||
self._id = canvas_id
|
||||
|
||||
def load(self):
|
||||
super().load()
|
||||
@ -721,6 +722,9 @@ class Canvas(Graph):
|
||||
def get_mode(self):
|
||||
return self.components["begin"]["obj"]._param.mode
|
||||
|
||||
def get_sys_query(self):
|
||||
return self.globals.get("sys.query", "")
|
||||
|
||||
def set_global_param(self, **kwargs):
|
||||
self.globals.update(kwargs)
|
||||
|
||||
|
||||
@ -33,6 +33,8 @@ from common.connection_utils import timeout
|
||||
from common.misc_utils import get_uuid
|
||||
from common import settings
|
||||
|
||||
from api.db.joint_services.memory_message_service import save_to_memory
|
||||
|
||||
|
||||
class MessageParam(ComponentParamBase):
|
||||
"""
|
||||
@ -166,6 +168,7 @@ class Message(ComponentBase):
|
||||
|
||||
self.set_output("content", all_content)
|
||||
self._convert_content(all_content)
|
||||
await self._save_to_memory(all_content)
|
||||
|
||||
def _is_jinjia2(self, content:str) -> bool:
|
||||
patt = [
|
||||
@ -198,6 +201,7 @@ class Message(ComponentBase):
|
||||
|
||||
self.set_output("content", content)
|
||||
self._convert_content(content)
|
||||
self._save_to_memory(content)
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return ""
|
||||
@ -421,3 +425,29 @@ class Message(ComponentBase):
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error converting content to {self._param.output_format}: {e}")
|
||||
|
||||
async def _save_to_memory(self, content):
|
||||
if not self._param.memory_ids:
|
||||
return True, "No memory selected."
|
||||
|
||||
message_dict = {
|
||||
"user_id": self._canvas._tenant_id,
|
||||
"agent_id": self._canvas._id,
|
||||
"session_id": self._canvas.task_id,
|
||||
"user_input": self._canvas.get_sys_query(),
|
||||
"agent_response": content
|
||||
}
|
||||
res = []
|
||||
for memory_id in self._param.memory_ids:
|
||||
success, msg = await save_to_memory(memory_id, message_dict)
|
||||
res.append({
|
||||
"memory_id": memory_id,
|
||||
"success": success,
|
||||
"msg": msg
|
||||
})
|
||||
if all([r["success"] for r in res]):
|
||||
return True, "Successfully added to memories."
|
||||
|
||||
error_text = "Some messages failed to add. " + " ".join([f"Add to memory {r['memory_id']} failed, detail: {r['msg']}" for r in res if not r["success"]])
|
||||
logging.error(error_text)
|
||||
return False, error_text
|
||||
|
||||
@ -25,10 +25,12 @@ from api.db.services.document_service import DocumentService
|
||||
from common.metadata_utils import apply_meta_data_filter
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.memory_service import MemoryService
|
||||
from api.db.joint_services import memory_message_service
|
||||
from common import settings
|
||||
from common.connection_utils import timeout
|
||||
from rag.app.tag import label_question
|
||||
from rag.prompts.generator import cross_languages, kb_prompt
|
||||
from rag.prompts.generator import cross_languages, kb_prompt, memory_prompt
|
||||
|
||||
|
||||
class RetrievalParam(ToolParamBase):
|
||||
@ -57,6 +59,7 @@ class RetrievalParam(ToolParamBase):
|
||||
self.top_n = 8
|
||||
self.top_k = 1024
|
||||
self.kb_ids = []
|
||||
self.memory_ids = []
|
||||
self.kb_vars = []
|
||||
self.rerank_id = ""
|
||||
self.empty_response = ""
|
||||
@ -81,15 +84,7 @@ class RetrievalParam(ToolParamBase):
|
||||
class Retrieval(ToolBase, ABC):
|
||||
component_name = "Retrieval"
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||
async def _invoke_async(self, **kwargs):
|
||||
if self.check_if_canceled("Retrieval processing"):
|
||||
return
|
||||
|
||||
if not kwargs.get("query"):
|
||||
self.set_output("formalized_content", self._param.empty_response)
|
||||
return
|
||||
|
||||
async def _retrieve_kb(self, query_text: str):
|
||||
kb_ids: list[str] = []
|
||||
for id in self._param.kb_ids:
|
||||
if id.find("@") < 0:
|
||||
@ -124,12 +119,12 @@ class Retrieval(ToolBase, ABC):
|
||||
if self._param.rerank_id:
|
||||
rerank_mdl = LLMBundle(kbs[0].tenant_id, LLMType.RERANK, self._param.rerank_id)
|
||||
|
||||
vars = self.get_input_elements_from_text(kwargs["query"])
|
||||
vars = {k:o["value"] for k,o in vars.items()}
|
||||
query = self.string_format(kwargs["query"], vars)
|
||||
vars = self.get_input_elements_from_text(query_text)
|
||||
vars = {k: o["value"] for k, o in vars.items()}
|
||||
query = self.string_format(query_text, vars)
|
||||
|
||||
doc_ids=[]
|
||||
if self._param.meta_data_filter!={}:
|
||||
doc_ids = []
|
||||
if self._param.meta_data_filter != {}:
|
||||
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
||||
|
||||
def _resolve_manual_filter(flt: dict) -> dict:
|
||||
@ -198,18 +193,20 @@ class Retrieval(ToolBase, ABC):
|
||||
|
||||
if self._param.toc_enhance:
|
||||
chat_mdl = LLMBundle(self._canvas._tenant_id, LLMType.CHAT)
|
||||
cks = settings.retriever.retrieval_by_toc(query, kbinfos["chunks"], [kb.tenant_id for kb in kbs], chat_mdl, self._param.top_n)
|
||||
cks = settings.retriever.retrieval_by_toc(query, kbinfos["chunks"], [kb.tenant_id for kb in kbs],
|
||||
chat_mdl, self._param.top_n)
|
||||
if self.check_if_canceled("Retrieval processing"):
|
||||
return
|
||||
if cks:
|
||||
kbinfos["chunks"] = cks
|
||||
kbinfos["chunks"] = settings.retriever.retrieval_by_children(kbinfos["chunks"], [kb.tenant_id for kb in kbs])
|
||||
kbinfos["chunks"] = settings.retriever.retrieval_by_children(kbinfos["chunks"],
|
||||
[kb.tenant_id for kb in kbs])
|
||||
if self._param.use_kg:
|
||||
ck = settings.kg_retriever.retrieval(query,
|
||||
[kb.tenant_id for kb in kbs],
|
||||
kb_ids,
|
||||
embd_mdl,
|
||||
LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT))
|
||||
[kb.tenant_id for kb in kbs],
|
||||
kb_ids,
|
||||
embd_mdl,
|
||||
LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT))
|
||||
if self.check_if_canceled("Retrieval processing"):
|
||||
return
|
||||
if ck["content_with_weight"]:
|
||||
@ -218,7 +215,8 @@ class Retrieval(ToolBase, ABC):
|
||||
kbinfos = {"chunks": [], "doc_aggs": []}
|
||||
|
||||
if self._param.use_kg and kbs:
|
||||
ck = settings.kg_retriever.retrieval(query, [kb.tenant_id for kb in kbs], filtered_kb_ids, embd_mdl, LLMBundle(kbs[0].tenant_id, LLMType.CHAT))
|
||||
ck = settings.kg_retriever.retrieval(query, [kb.tenant_id for kb in kbs], filtered_kb_ids, embd_mdl,
|
||||
LLMBundle(kbs[0].tenant_id, LLMType.CHAT))
|
||||
if self.check_if_canceled("Retrieval processing"):
|
||||
return
|
||||
if ck["content_with_weight"]:
|
||||
@ -248,6 +246,50 @@ class Retrieval(ToolBase, ABC):
|
||||
|
||||
return form_cnt
|
||||
|
||||
async def _retrieve_memory(self, query_text: str):
|
||||
memory_ids: list[str] = [memory_id for memory_id in self._param.memory_ids]
|
||||
memory_list = MemoryService.get_by_ids(memory_ids)
|
||||
if not memory_list:
|
||||
raise Exception("No memory is selected.")
|
||||
|
||||
embd_names = list({memory.embd_id for memory in memory_list})
|
||||
assert len(embd_names) == 1, "Memory use different embedding models."
|
||||
|
||||
vars = self.get_input_elements_from_text(query_text)
|
||||
vars = {k: o["value"] for k, o in vars.items()}
|
||||
query = self.string_format(query_text, vars)
|
||||
# query message
|
||||
message_list = memory_message_service.query_message({"memory_id": memory_ids}, {
|
||||
"query": query,
|
||||
"similarity_threshold": self._param.similarity_threshold,
|
||||
"keywords_similarity_weight": self._param.keywords_similarity_weight,
|
||||
"top_n": self._param.top_n
|
||||
})
|
||||
if not message_list:
|
||||
self.set_output("formalized_content", self._param.empty_response)
|
||||
return ""
|
||||
formated_content = "\n".join(memory_prompt(message_list, 200000))
|
||||
# set formalized_content output
|
||||
self.set_output("formalized_content", formated_content)
|
||||
|
||||
return formated_content
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||
async def _invoke_async(self, **kwargs):
|
||||
if self.check_if_canceled("Retrieval processing"):
|
||||
return
|
||||
if not kwargs.get("query"):
|
||||
self.set_output("formalized_content", self._param.empty_response)
|
||||
return
|
||||
|
||||
if self._param.kb_ids:
|
||||
return await self._retrieve_kb(kwargs["query"])
|
||||
elif hasattr(self._param, "memory_ids") and self._param.memory_ids:
|
||||
return await self._retrieve_memory(kwargs["query"])
|
||||
else:
|
||||
self.set_output("formalized_content", self._param.empty_response)
|
||||
return
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||
def _invoke(self, **kwargs):
|
||||
return asyncio.run(self._invoke_async(**kwargs))
|
||||
|
||||
Reference in New Issue
Block a user