Feat: user defined prompt. (#9972)

### What problem does this PR solve?


### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Kevin Hu
2025-09-08 14:05:01 +08:00
committed by GitHub
parent cf18231713
commit e9ee9269f5
11 changed files with 203 additions and 66 deletions

View File

@ -155,12 +155,12 @@ class Agent(LLM, ToolBase):
if not self.tools:
return LLM._invoke(self, **kwargs)
prompt, msg = self._prepare_prompt_variables()
prompt, msg, user_defined_prompt = self._prepare_prompt_variables()
downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
ex = self.exception_handler()
if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not self._param.output_structure and not (ex and ex["goto"]):
self.set_output("content", partial(self.stream_output_with_tools, prompt, msg))
self.set_output("content", partial(self.stream_output_with_tools, prompt, msg, user_defined_prompt))
return
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
@ -182,11 +182,11 @@ class Agent(LLM, ToolBase):
self.set_output("use_tools", use_tools)
return ans
def stream_output_with_tools(self, prompt, msg):
def stream_output_with_tools(self, prompt, msg, user_defined_prompt={}):
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
answer_without_toolcall = ""
use_tools = []
for delta_ans,_ in self._react_with_tools_streamly(prompt, msg, use_tools):
for delta_ans,_ in self._react_with_tools_streamly(prompt, msg, use_tools, user_defined_prompt):
if delta_ans.find("**ERROR**") >= 0:
if self.get_exception_default_value():
self.set_output("content", self.get_exception_default_value())
@ -209,7 +209,7 @@ class Agent(LLM, ToolBase):
]):
yield delta_ans
def _react_with_tools_streamly(self, prompt, history: list[dict], use_tools):
def _react_with_tools_streamly(self, prompt, history: list[dict], use_tools, user_defined_prompt={}):
token_count = 0
tool_metas = self.tool_meta
hist = deepcopy(history)
@ -230,7 +230,7 @@ class Agent(LLM, ToolBase):
# last_calling,
# last_calling != name
#]):
# self.toolcall_session.get_tool_obj(name).add2system_prompt(f"The chat history with other agents are as following: \n" + self.get_useful_memory(user_request, str(args["user_prompt"])))
# self.toolcall_session.get_tool_obj(name).add2system_prompt(f"The chat history with other agents are as following: \n" + self.get_useful_memory(user_request, str(args["user_prompt"]),user_defined_prompt))
last_calling = name
tool_response = self.toolcall_session.tool_call(name, args)
use_tools.append({
@ -239,7 +239,7 @@ class Agent(LLM, ToolBase):
"results": tool_response
})
# self.callback("add_memory", {}, "...")
#self.add_memory(hist[-2]["content"], hist[-1]["content"], name, args, str(tool_response))
#self.add_memory(hist[-2]["content"], hist[-1]["content"], name, args, str(tool_response), user_defined_prompt)
return name, tool_response
@ -279,10 +279,10 @@ class Agent(LLM, ToolBase):
hist.append({"role": "user", "content": content})
st = timer()
task_desc = analyze_task(self.chat_mdl, prompt, user_request, tool_metas)
task_desc = analyze_task(self.chat_mdl, prompt, user_request, tool_metas, user_defined_prompt)
self.callback("analyze_task", {}, task_desc, elapsed_time=timer()-st)
for _ in range(self._param.max_rounds + 1):
response, tk = next_step(self.chat_mdl, hist, tool_metas, task_desc)
response, tk = next_step(self.chat_mdl, hist, tool_metas, task_desc, user_defined_prompt)
# self.callback("next_step", {}, str(response)[:256]+"...")
token_count += tk
hist.append({"role": "assistant", "content": response})
@ -307,7 +307,7 @@ class Agent(LLM, ToolBase):
thr.append(executor.submit(use_tool, name, args))
st = timer()
reflection = reflect(self.chat_mdl, hist, [th.result() for th in thr])
reflection = reflect(self.chat_mdl, hist, [th.result() for th in thr], user_defined_prompt)
append_user_content(hist, reflection)
self.callback("reflection", {}, str(reflection), elapsed_time=timer()-st)
@ -334,10 +334,10 @@ Respond immediately with your final comprehensive answer.
for txt, tkcnt in complete():
yield txt, tkcnt
def get_useful_memory(self, goal: str, sub_goal:str, topn=3) -> str:
def get_useful_memory(self, goal: str, sub_goal:str, topn=3, user_defined_prompt:dict={}) -> str:
# self.callback("get_useful_memory", {"topn": 3}, "...")
mems = self._canvas.get_memory()
rank = rank_memories(self.chat_mdl, goal, sub_goal, [summ for (user, assist, summ) in mems])
rank = rank_memories(self.chat_mdl, goal, sub_goal, [summ for (user, assist, summ) in mems], user_defined_prompt)
try:
rank = json_repair.loads(re.sub(r"```.*", "", rank))[:topn]
mems = [mems[r] for r in rank]

View File

@ -145,12 +145,23 @@ class LLM(ComponentBase):
msg.append(deepcopy(p))
sys_prompt = self.string_format(sys_prompt, args)
user_defined_prompt, sys_prompt = self._extract_prompts(sys_prompt)
for m in msg:
m["content"] = self.string_format(m["content"], args)
if self._param.cite and self._canvas.get_reference()["chunks"]:
sys_prompt += citation_prompt()
sys_prompt += citation_prompt(user_defined_prompt)
return sys_prompt, msg
return sys_prompt, msg, user_defined_prompt
def _extract_prompts(self, sys_prompt):
pts = {}
for tag in ["TASK_ANALYSIS", "PLAN_GENERATION", "REFLECTION", "CONTEXT_SUMMARY", "CONTEXT_RANKING", "CITATION_GUIDELINES"]:
r = re.search(rf"<{tag}>(.*?)</{tag}>", sys_prompt, flags=re.DOTALL|re.IGNORECASE)
if not r:
continue
pts[tag.lower()] = r.group(1)
sys_prompt = re.sub(rf"<{tag}>(.*?)</{tag}>", sys_prompt, flags=re.DOTALL|re.IGNORECASE)
return pts, sys_prompt
def _generate(self, msg:list[dict], **kwargs) -> str:
if not self.imgs:
@ -198,7 +209,7 @@ class LLM(ComponentBase):
ans = re.sub(r"^.*```json", "", ans, flags=re.DOTALL)
return re.sub(r"```\n*$", "", ans, flags=re.DOTALL)
prompt, msg = self._prepare_prompt_variables()
prompt, msg, _ = self._prepare_prompt_variables()
error = ""
if self._param.output_structure:
@ -262,11 +273,11 @@ class LLM(ComponentBase):
answer += ans
self.set_output("content", answer)
def add_memory(self, user:str, assist:str, func_name: str, params: dict, results: str):
summ = tool_call_summary(self.chat_mdl, func_name, params, results)
def add_memory(self, user:str, assist:str, func_name: str, params: dict, results: str, user_defined_prompt:dict={}):
summ = tool_call_summary(self.chat_mdl, func_name, params, results, user_defined_prompt)
logging.info(f"[MEMORY]: {summ}")
self._canvas.add_memory(user, assist, summ)
def thoughts(self) -> str:
_, msg = self._prepare_prompt_variables()
_, msg,_ = self._prepare_prompt_variables()
return "⌛Give me a moment—starting from: \n\n" + re.sub(r"(User's query:|[\\]+)", '', msg[-1]['content'], flags=re.DOTALL) + "\n\nIll figure out our best next move."