mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Feat: user defined prompt. (#9972)
### What problem does this PR solve? ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -155,12 +155,12 @@ class Agent(LLM, ToolBase):
|
||||
if not self.tools:
|
||||
return LLM._invoke(self, **kwargs)
|
||||
|
||||
prompt, msg = self._prepare_prompt_variables()
|
||||
prompt, msg, user_defined_prompt = self._prepare_prompt_variables()
|
||||
|
||||
downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
|
||||
ex = self.exception_handler()
|
||||
if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not self._param.output_structure and not (ex and ex["goto"]):
|
||||
self.set_output("content", partial(self.stream_output_with_tools, prompt, msg))
|
||||
self.set_output("content", partial(self.stream_output_with_tools, prompt, msg, user_defined_prompt))
|
||||
return
|
||||
|
||||
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
|
||||
@ -182,11 +182,11 @@ class Agent(LLM, ToolBase):
|
||||
self.set_output("use_tools", use_tools)
|
||||
return ans
|
||||
|
||||
def stream_output_with_tools(self, prompt, msg):
|
||||
def stream_output_with_tools(self, prompt, msg, user_defined_prompt={}):
|
||||
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
|
||||
answer_without_toolcall = ""
|
||||
use_tools = []
|
||||
for delta_ans,_ in self._react_with_tools_streamly(prompt, msg, use_tools):
|
||||
for delta_ans,_ in self._react_with_tools_streamly(prompt, msg, use_tools, user_defined_prompt):
|
||||
if delta_ans.find("**ERROR**") >= 0:
|
||||
if self.get_exception_default_value():
|
||||
self.set_output("content", self.get_exception_default_value())
|
||||
@ -209,7 +209,7 @@ class Agent(LLM, ToolBase):
|
||||
]):
|
||||
yield delta_ans
|
||||
|
||||
def _react_with_tools_streamly(self, prompt, history: list[dict], use_tools):
|
||||
def _react_with_tools_streamly(self, prompt, history: list[dict], use_tools, user_defined_prompt={}):
|
||||
token_count = 0
|
||||
tool_metas = self.tool_meta
|
||||
hist = deepcopy(history)
|
||||
@ -230,7 +230,7 @@ class Agent(LLM, ToolBase):
|
||||
# last_calling,
|
||||
# last_calling != name
|
||||
#]):
|
||||
# self.toolcall_session.get_tool_obj(name).add2system_prompt(f"The chat history with other agents are as following: \n" + self.get_useful_memory(user_request, str(args["user_prompt"])))
|
||||
# self.toolcall_session.get_tool_obj(name).add2system_prompt(f"The chat history with other agents are as following: \n" + self.get_useful_memory(user_request, str(args["user_prompt"]),user_defined_prompt))
|
||||
last_calling = name
|
||||
tool_response = self.toolcall_session.tool_call(name, args)
|
||||
use_tools.append({
|
||||
@ -239,7 +239,7 @@ class Agent(LLM, ToolBase):
|
||||
"results": tool_response
|
||||
})
|
||||
# self.callback("add_memory", {}, "...")
|
||||
#self.add_memory(hist[-2]["content"], hist[-1]["content"], name, args, str(tool_response))
|
||||
#self.add_memory(hist[-2]["content"], hist[-1]["content"], name, args, str(tool_response), user_defined_prompt)
|
||||
|
||||
return name, tool_response
|
||||
|
||||
@ -279,10 +279,10 @@ class Agent(LLM, ToolBase):
|
||||
hist.append({"role": "user", "content": content})
|
||||
|
||||
st = timer()
|
||||
task_desc = analyze_task(self.chat_mdl, prompt, user_request, tool_metas)
|
||||
task_desc = analyze_task(self.chat_mdl, prompt, user_request, tool_metas, user_defined_prompt)
|
||||
self.callback("analyze_task", {}, task_desc, elapsed_time=timer()-st)
|
||||
for _ in range(self._param.max_rounds + 1):
|
||||
response, tk = next_step(self.chat_mdl, hist, tool_metas, task_desc)
|
||||
response, tk = next_step(self.chat_mdl, hist, tool_metas, task_desc, user_defined_prompt)
|
||||
# self.callback("next_step", {}, str(response)[:256]+"...")
|
||||
token_count += tk
|
||||
hist.append({"role": "assistant", "content": response})
|
||||
@ -307,7 +307,7 @@ class Agent(LLM, ToolBase):
|
||||
thr.append(executor.submit(use_tool, name, args))
|
||||
|
||||
st = timer()
|
||||
reflection = reflect(self.chat_mdl, hist, [th.result() for th in thr])
|
||||
reflection = reflect(self.chat_mdl, hist, [th.result() for th in thr], user_defined_prompt)
|
||||
append_user_content(hist, reflection)
|
||||
self.callback("reflection", {}, str(reflection), elapsed_time=timer()-st)
|
||||
|
||||
@ -334,10 +334,10 @@ Respond immediately with your final comprehensive answer.
|
||||
for txt, tkcnt in complete():
|
||||
yield txt, tkcnt
|
||||
|
||||
def get_useful_memory(self, goal: str, sub_goal:str, topn=3) -> str:
|
||||
def get_useful_memory(self, goal: str, sub_goal:str, topn=3, user_defined_prompt:dict={}) -> str:
|
||||
# self.callback("get_useful_memory", {"topn": 3}, "...")
|
||||
mems = self._canvas.get_memory()
|
||||
rank = rank_memories(self.chat_mdl, goal, sub_goal, [summ for (user, assist, summ) in mems])
|
||||
rank = rank_memories(self.chat_mdl, goal, sub_goal, [summ for (user, assist, summ) in mems], user_defined_prompt)
|
||||
try:
|
||||
rank = json_repair.loads(re.sub(r"```.*", "", rank))[:topn]
|
||||
mems = [mems[r] for r in rank]
|
||||
|
||||
@ -145,12 +145,23 @@ class LLM(ComponentBase):
|
||||
msg.append(deepcopy(p))
|
||||
|
||||
sys_prompt = self.string_format(sys_prompt, args)
|
||||
user_defined_prompt, sys_prompt = self._extract_prompts(sys_prompt)
|
||||
for m in msg:
|
||||
m["content"] = self.string_format(m["content"], args)
|
||||
if self._param.cite and self._canvas.get_reference()["chunks"]:
|
||||
sys_prompt += citation_prompt()
|
||||
sys_prompt += citation_prompt(user_defined_prompt)
|
||||
|
||||
return sys_prompt, msg
|
||||
return sys_prompt, msg, user_defined_prompt
|
||||
|
||||
def _extract_prompts(self, sys_prompt):
|
||||
pts = {}
|
||||
for tag in ["TASK_ANALYSIS", "PLAN_GENERATION", "REFLECTION", "CONTEXT_SUMMARY", "CONTEXT_RANKING", "CITATION_GUIDELINES"]:
|
||||
r = re.search(rf"<{tag}>(.*?)</{tag}>", sys_prompt, flags=re.DOTALL|re.IGNORECASE)
|
||||
if not r:
|
||||
continue
|
||||
pts[tag.lower()] = r.group(1)
|
||||
sys_prompt = re.sub(rf"<{tag}>(.*?)</{tag}>", sys_prompt, flags=re.DOTALL|re.IGNORECASE)
|
||||
return pts, sys_prompt
|
||||
|
||||
def _generate(self, msg:list[dict], **kwargs) -> str:
|
||||
if not self.imgs:
|
||||
@ -198,7 +209,7 @@ class LLM(ComponentBase):
|
||||
ans = re.sub(r"^.*```json", "", ans, flags=re.DOTALL)
|
||||
return re.sub(r"```\n*$", "", ans, flags=re.DOTALL)
|
||||
|
||||
prompt, msg = self._prepare_prompt_variables()
|
||||
prompt, msg, _ = self._prepare_prompt_variables()
|
||||
error = ""
|
||||
|
||||
if self._param.output_structure:
|
||||
@ -262,11 +273,11 @@ class LLM(ComponentBase):
|
||||
answer += ans
|
||||
self.set_output("content", answer)
|
||||
|
||||
def add_memory(self, user:str, assist:str, func_name: str, params: dict, results: str):
|
||||
summ = tool_call_summary(self.chat_mdl, func_name, params, results)
|
||||
def add_memory(self, user:str, assist:str, func_name: str, params: dict, results: str, user_defined_prompt:dict={}):
|
||||
summ = tool_call_summary(self.chat_mdl, func_name, params, results, user_defined_prompt)
|
||||
logging.info(f"[MEMORY]: {summ}")
|
||||
self._canvas.add_memory(user, assist, summ)
|
||||
|
||||
def thoughts(self) -> str:
|
||||
_, msg = self._prepare_prompt_variables()
|
||||
_, msg,_ = self._prepare_prompt_variables()
|
||||
return "⌛Give me a moment—starting from: \n\n" + re.sub(r"(User's query:|[\\]+)", '', msg[-1]['content'], flags=re.DOTALL) + "\n\nI’ll figure out our best next move."
|
||||
Reference in New Issue
Block a user