Feat: Support metadata auto filter for Search. (#9524)
### What problem does this PR solve?

Support automatic metadata filtering for Search.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
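Most of the diff below wires elapsed-time tracing through agent callbacks, makes `ExeSQL` results JSON-safe, and moves the `ask` summary prompt into a Jinja template; the metadata auto-filter itself surfaces here mainly through the `gen_meta_filter` import. As a rough, self-contained sketch of what metadata auto-filtering means (the helper name, prompt, and fake model below are illustrative, not this PR's implementation):

```python
import json

def gen_meta_filter_sketch(chat_fn, meta_fields: list[str], question: str) -> dict:
    """Toy stand-in for the gen_meta_filter prompt: ask the model to map a
    free-text question onto the knowledge base's metadata fields."""
    prompt = (
        "Given the metadata fields %s, extract filter conditions from the "
        "question below and answer with a JSON object only.\nQuestion: %s"
        % (meta_fields, question)
    )
    try:
        return json.loads(chat_fn(prompt))
    except json.JSONDecodeError:
        return {}  # fall back to unfiltered search

# Usage with a fake model standing in for the LLM:
fake_llm = lambda p: '{"year": "2024", "department": "finance"}'
print(gen_meta_filter_sketch(fake_llm, ["year", "department"],
                             "2024 finance reports on Q3 spend"))
# -> {'year': '2024', 'department': 'finance'}
```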
Agent trace logging in `Canvas`:

```diff
@@ -484,7 +484,7 @@ class Canvas:
                 threads.append(exe.submit(FileService.parse, file["name"], FileService.get_blob(file["created_by"], file["id"]), True, file["created_by"]))
             return [th.result() for th in threads]
 
-    def tool_use_callback(self, agent_id: str, func_name: str, params: dict, result: Any):
+    def tool_use_callback(self, agent_id: str, func_name: str, params: dict, result: Any, elapsed_time=None):
         agent_ids = agent_id.split("-->")
         agent_name = self.get_component_name(agent_ids[0])
         path = agent_name if len(agent_ids) < 2 else agent_name+"-->"+"-->".join(agent_ids[1:])
@@ -493,16 +493,16 @@ class Canvas:
             if bin:
                 obj = json.loads(bin.encode("utf-8"))
                 if obj[-1]["component_id"] == agent_ids[0]:
-                    obj[-1]["trace"].append({"path": path, "tool_name": func_name, "arguments": params, "result": result})
+                    obj[-1]["trace"].append({"path": path, "tool_name": func_name, "arguments": params, "result": result, "elapsed_time": elapsed_time})
                 else:
                     obj.append({
                         "component_id": agent_ids[0],
-                        "trace": [{"path": path, "tool_name": func_name, "arguments": params, "result": result}]
+                        "trace": [{"path": path, "tool_name": func_name, "arguments": params, "result": result, "elapsed_time": elapsed_time}]
                     })
             else:
                 obj = [{
                     "component_id": agent_ids[0],
-                    "trace": [{"path": path, "tool_name": func_name, "arguments": params, "result": result}]
+                    "trace": [{"path": path, "tool_name": func_name, "arguments": params, "result": result, "elapsed_time": elapsed_time}]
                 }]
             REDIS_CONN.set_obj(f"{self.task_id}-{self.message_id}-logs", obj, 60*10)
         except Exception as e:
```
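With `elapsed_time` threaded through, each trace entry stored under the `{task_id}-{message_id}-logs` Redis key now records how long the step took. Illustrative shape of the stored object (values made up, keys taken from the diff above):

```python
# Illustrative log object as written by tool_use_callback after this change.
logs = [
    {
        "component_id": "Agent:0",           # illustrative component id
        "trace": [
            {
                "path": "Agent",
                "tool_name": "web_search",   # illustrative tool name
                "arguments": {"query": "..."},
                "result": "...",
                "elapsed_time": 1.73,        # seconds (timer() delta); None for callers not yet updated
            }
        ],
    }
]
# Stored with a 10-minute TTL, mirroring REDIS_CONN.set_obj(..., obj, 60*10).
```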
Step timing in `Agent`:

```diff
@@ -22,7 +22,7 @@ from functools import partial
 from typing import Any
 
 import json_repair
+from timeit import default_timer as timer
 from agent.tools.base import LLMToolPluginCallSession, ToolParamBase, ToolBase, ToolMeta
 from api.db.services.llm_service import LLMBundle
 from api.db.services.tenant_llm_service import TenantLLMService
@@ -215,8 +215,9 @@ class Agent(LLM, ToolBase):
         hist = deepcopy(history)
         last_calling = ""
         if len(hist) > 3:
+            st = timer()
             user_request = full_question(messages=history, chat_mdl=self.chat_mdl)
-            self.callback("Multi-turn conversation optimization", {}, user_request)
+            self.callback("Multi-turn conversation optimization", {}, user_request, elapsed_time=timer()-st)
         else:
             user_request = history[-1]["content"]
 
@@ -263,12 +264,13 @@ class Agent(LLM, ToolBase):
         if not need2cite or cited:
             return
 
+        st = timer()
         txt = ""
         for delta_ans in self._gen_citations(entire_txt):
             yield delta_ans, 0
             txt += delta_ans
 
-        self.callback("gen_citations", {}, txt)
+        self.callback("gen_citations", {}, txt, elapsed_time=timer()-st)
 
         def append_user_content(hist, content):
             if hist[-1]["role"] == "user":
@@ -276,8 +278,9 @@ class Agent(LLM, ToolBase):
             else:
                 hist.append({"role": "user", "content": content})
 
+        st = timer()
         task_desc = analyze_task(self.chat_mdl, prompt, user_request, tool_metas)
-        self.callback("analyze_task", {}, task_desc)
+        self.callback("analyze_task", {}, task_desc, elapsed_time=timer()-st)
         for _ in range(self._param.max_rounds + 1):
             response, tk = next_step(self.chat_mdl, hist, tool_metas, task_desc)
             # self.callback("next_step", {}, str(response)[:256]+"...")
@@ -303,9 +306,10 @@ class Agent(LLM, ToolBase):
 
                     thr.append(executor.submit(use_tool, name, args))
 
+                st = timer()
                 reflection = reflect(self.chat_mdl, hist, [th.result() for th in thr])
                 append_user_content(hist, reflection)
-                self.callback("reflection", {}, str(reflection))
+                self.callback("reflection", {}, str(reflection), elapsed_time=timer()-st)
 
             except Exception as e:
                 logging.exception(msg=f"Wrong JSON argument format in LLM ReAct response: {e}")
```
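Every hunk above follows the same pattern: snapshot `timeit.default_timer` before the step and pass the delta to the callback as `elapsed_time`. A minimal runnable sketch of that pattern in isolation (the callback body is illustrative):

```python
from timeit import default_timer as timer

def callback(event: str, params: dict, result, elapsed_time=None):
    # Stand-in for self.callback / Canvas.tool_use_callback in the diff.
    print(f"{event}: {elapsed_time:.3f}s" if elapsed_time is not None else event)

st = timer()
result = sum(i * i for i in range(1_000_000))  # any timed step
callback("analyze_task", {}, result, elapsed_time=timer() - st)
```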
Tool-call timing in `LLMToolPluginCallSession`:

```diff
@@ -24,6 +24,7 @@ from api.utils import hash_str2int
 from rag.llm.chat_model import ToolCallSession
 from rag.prompts.prompts import kb_prompt
 from rag.utils.mcp_tool_call_conn import MCPToolCallSession
+from timeit import default_timer as timer
 
 
 class ToolParameter(TypedDict):
@@ -49,12 +50,13 @@ class LLMToolPluginCallSession(ToolCallSession):
 
     def tool_call(self, name: str, arguments: dict[str, Any]) -> Any:
         assert name in self.tools_map, f"LLM tool {name} does not exist"
+        st = timer()
         if isinstance(self.tools_map[name], MCPToolCallSession):
             resp = self.tools_map[name].tool_call(name, arguments, 60)
         else:
             resp = self.tools_map[name].invoke(**arguments)
 
-        self.callback(name, arguments, resp)
+        self.callback(name, arguments, resp, elapsed_time=timer()-st)
         return resp
 
     def get_tool_obj(self, name):
```
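Here a single timer wraps both dispatch branches, so MCP calls and local plugin invocations report timings uniformly. A condensed, self-contained version of the flow (the `PluginTool` stub and callback are illustrative; the MCP branch is omitted):

```python
from timeit import default_timer as timer

class PluginTool:
    def invoke(self, **kwargs):
        return {"rows": 3}

def tool_call(tools_map, callback, name, arguments):
    assert name in tools_map, f"LLM tool {name} does not exist"
    st = timer()
    resp = tools_map[name].invoke(**arguments)  # MCP branch omitted in this sketch
    callback(name, arguments, resp, elapsed_time=timer() - st)
    return resp

tool_call({"sql": PluginTool()},
          lambda n, a, r, elapsed_time: print(n, elapsed_time),
          "sql", {"query": "SELECT 1"})
```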
JSON-safe results in `ExeSQL`:

```diff
@@ -79,6 +79,17 @@ class ExeSQL(ToolBase, ABC):
 
     @timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60))
     def _invoke(self, **kwargs):
+
+        def convert_decimals(obj):
+            from decimal import Decimal
+            if isinstance(obj, Decimal):
+                return float(obj)  # or str(obj)
+            elif isinstance(obj, dict):
+                return {k: convert_decimals(v) for k, v in obj.items()}
+            elif isinstance(obj, list):
+                return [convert_decimals(item) for item in obj]
+            return obj
+
         sql = kwargs.get("sql")
         if not sql:
             raise Exception("SQL for `ExeSQL` MUST not be empty.")
@@ -122,7 +133,11 @@ class ExeSQL(ToolBase, ABC):
             single_res = pd.DataFrame([i for i in cursor.fetchmany(self._param.max_records)])
             single_res.columns = [i[0] for i in cursor.description]
 
-            sql_res.append(single_res.to_dict(orient='records'))
+            for col in single_res.columns:
+                if pd.api.types.is_datetime64_any_dtype(single_res[col]):
+                    single_res[col] = single_res[col].dt.strftime('%Y-%m-%d')
+
+            sql_res.append(convert_decimals(single_res.to_dict(orient='records')))
             formalized_content.append(single_res.to_markdown(index=False, floatfmt=".6f"))
 
             self.set_output("json", sql_res)
```
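The motivation for `convert_decimals` and the date formatting: database cursors can hand back `decimal.Decimal` and datetime values, which survive `DataFrame.to_dict` untouched but crash `json` serialization. A quick demonstration (column names are illustrative):

```python
import json
from decimal import Decimal

import pandas as pd

df = pd.DataFrame({"amount": [Decimal("19.99")],
                   "billed_at": pd.to_datetime(["2024-07-01"])})

records = df.to_dict(orient="records")
try:
    json.dumps(records)
except TypeError as e:
    print("unserializable:", e)  # Decimal / Timestamp are not JSON types

# After the fixes in this hunk:
df["billed_at"] = df["billed_at"].dt.strftime("%Y-%m-%d")
records = [{k: float(v) if isinstance(v, Decimal) else v for k, v in r.items()}
           for r in df.to_dict(orient="records")]
print(json.dumps(records))  # [{"amount": 19.99, "billed_at": "2024-07-01"}]
```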
Prompt extraction and metadata-filter imports in `ask(...)`:

```diff
@@ -40,7 +40,7 @@ from rag.app.resume import forbidden_select_fields4resume
 from rag.app.tag import label_question
 from rag.nlp.search import index_name
 from rag.prompts import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in
-from rag.prompts.prompts import gen_meta_filter
+from rag.prompts.prompts import gen_meta_filter, PROMPT_JINJA_ENV, ASK_SUMMARY
 from rag.utils import num_tokens_from_string, rmSpace
 from rag.utils.tavily_conn import Tavily
 
@@ -723,6 +723,7 @@ def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
         rerank_mdl = LLMBundle(tenant_id, LLMType.RERANK, rerank_id)
     max_tokens = chat_mdl.max_length
     tenant_ids = list(set([kb.tenant_id for kb in kbs]))
+
     kbinfos = retriever.retrieval(
         question=question,
         embd_mdl=embd_mdl,
@@ -740,26 +741,12 @@ def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
     )
 
     knowledges = kb_prompt(kbinfos, max_tokens)
-    prompt = """
-Role: You're a smart assistant. Your name is Miss R.
-Task: Summarize the information from knowledge bases and answer user's question.
-Requirements and restriction:
-- DO NOT make things up, especially for numbers.
-- If the information from knowledge is irrelevant with user's question, JUST SAY: Sorry, no relevant information provided.
-- Answer with markdown format text.
-- Answer in language of user's question.
-- DO NOT make things up, especially for numbers.
-
-### Information from knowledge bases
-%s
-
-The above is information from knowledge bases.
-""" % "\n".join(knowledges)
+    sys_prompt = PROMPT_JINJA_ENV.from_string(ASK_SUMMARY).render(knowledge="\n".join(knowledges))
 
     msg = [{"role": "user", "content": question}]
 
     def decorate_answer(answer):
-        nonlocal knowledges, kbinfos, prompt
+        nonlocal knowledges, kbinfos, sys_prompt
         answer, idx = retriever.insert_citations(answer, [ck["content_ltks"] for ck in kbinfos["chunks"]], [ck["vector"] for ck in kbinfos["chunks"]], embd_mdl, tkweight=0.7, vtweight=0.3)
         idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
         recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
@@ -777,7 +764,7 @@ def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
         return {"answer": answer, "reference": refs}
 
     answer = ""
-    for ans in chat_mdl.chat_streamly(prompt, msg, {"temperature": 0.1}):
+    for ans in chat_mdl.chat_streamly(sys_prompt, msg, {"temperature": 0.1}):
         answer = ans
         yield {"answer": answer, "reference": {}}
     yield decorate_answer(answer)
```
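The inline `%`-formatted prompt becomes a Jinja2 template rendered through the shared environment. The rendering step itself is plain `jinja2` (the template string is shortened here):

```python
import jinja2

# Mirrors PROMPT_JINJA_ENV in the diff.
env = jinja2.Environment(autoescape=False, trim_blocks=True, lstrip_blocks=True)

template = "### Information from knowledge bases\n\n{{ knowledge }}"
knowledges = ["chunk one", "chunk two"]  # illustrative kb_prompt output
sys_prompt = env.from_string(template).render(knowledge="\n".join(knowledges))
print(sys_prompt)
```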
Image-cleanup removal in `naive_merge_with_images`:

```diff
@@ -612,10 +612,6 @@ def naive_merge_with_images(texts, images, chunk_token_num=128, delimiter="\n。
                 continue
             add_chunk(sub_sec, image)
 
-    for img in images:
-        if isinstance(img, Image.Image):
-            img.close()
-
     return cks, result_images
 
 def docx_question_level(p, bull=-1):
```
New prompt template `rag/prompts/ask_summary.md` (new file, 14 lines):

```diff
--- /dev/null
+++ b/rag/prompts/ask_summary.md
@@ -0,0 +1,14 @@
+Role: You're a smart assistant. Your name is Miss R.
+Task: Summarize the information from knowledge bases and answer user's question.
+Requirements and restriction:
+- DO NOT make things up, especially for numbers.
+- If the information from knowledge is irrelevant with user's question, JUST SAY: Sorry, no relevant information provided.
+- Answer with markdown format text.
+- Answer in language of user's question.
+- DO NOT make things up, especially for numbers.
+
+### Information from knowledge bases
+
+{{ knowledge }}
+
+The above is information from knowledge bases.
```
Prompt registration:

```diff
@@ -150,6 +150,7 @@ REFLECT = load_prompt("reflect")
 SUMMARY4MEMORY = load_prompt("summary4memory")
 RANK_MEMORY = load_prompt("rank_memory")
 META_FILTER = load_prompt("meta_filter")
+ASK_SUMMARY = load_prompt("ask_summary")
 
 PROMPT_JINJA_ENV = jinja2.Environment(autoescape=False, trim_blocks=True, lstrip_blocks=True)
 
```
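`load_prompt` itself isn't part of this diff; presumably it reads the same-named markdown file from `rag/prompts/` and returns raw text for `PROMPT_JINJA_ENV.from_string(...)` to compile later. A minimal stand-in under that assumption:

```python
from pathlib import Path

def load_prompt_sketch(name: str, base_dir: str = "rag/prompts") -> str:
    # Assumption: one template per <name>.md file, returned as raw text.
    return Path(base_dir, f"{name}.md").read_text(encoding="utf-8")

# ASK_SUMMARY = load_prompt_sketch("ask_summary")  # would load the new file above
```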