Feat: user defined prompt. (#9972)

### What problem does this PR solve?


### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Kevin Hu
2025-09-08 14:05:01 +08:00
committed by GitHub
parent cf18231713
commit e9ee9269f5
11 changed files with 203 additions and 66 deletions

View File

@ -157,7 +157,7 @@ ASK_SUMMARY = load_prompt("ask_summary")
PROMPT_JINJA_ENV = jinja2.Environment(autoescape=False, trim_blocks=True, lstrip_blocks=True)
def citation_prompt() -> str:
def citation_prompt(user_defined_prompts: dict={}) -> str:
template = PROMPT_JINJA_ENV.from_string(CITATION_PROMPT_TEMPLATE)
return template.render()
@ -339,7 +339,7 @@ def form_history(history, limit=-6):
return context
def analyze_task(chat_mdl, prompt, task_name, tools_description: list[dict]):
def analyze_task(chat_mdl, prompt, task_name, tools_description: list[dict], user_defined_prompts: dict={}):
tools_desc = tool_schema(tools_description)
context = ""
@ -354,7 +354,7 @@ def analyze_task(chat_mdl, prompt, task_name, tools_description: list[dict]):
return kwd
def next_step(chat_mdl, history:list, tools_description: list[dict], task_desc):
def next_step(chat_mdl, history:list, tools_description: list[dict], task_desc, user_defined_prompts: dict={}):
if not tools_description:
return ""
desc = tool_schema(tools_description)
@ -372,7 +372,7 @@ def next_step(chat_mdl, history:list, tools_description: list[dict], task_desc):
return json_str, tk_cnt
def reflect(chat_mdl, history: list[dict], tool_call_res: list[Tuple]):
def reflect(chat_mdl, history: list[dict], tool_call_res: list[Tuple], user_defined_prompts: dict={}):
tool_calls = [{"name": p[0], "result": p[1]} for p in tool_call_res]
goal = history[1]["content"]
template = PROMPT_JINJA_ENV.from_string(REFLECT)
@ -398,7 +398,7 @@ def form_message(system_prompt, user_prompt):
return [{"role": "system", "content": system_prompt},{"role": "user", "content": user_prompt}]
def tool_call_summary(chat_mdl, name: str, params: dict, result: str) -> str:
def tool_call_summary(chat_mdl, name: str, params: dict, result: str, user_defined_prompts: dict={}) -> str:
template = PROMPT_JINJA_ENV.from_string(SUMMARY4MEMORY)
system_prompt = template.render(name=name,
params=json.dumps(params, ensure_ascii=False, indent=2),
@ -409,7 +409,7 @@ def tool_call_summary(chat_mdl, name: str, params: dict, result: str) -> str:
return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
def rank_memories(chat_mdl, goal:str, sub_goal:str, tool_call_summaries: list[str]):
def rank_memories(chat_mdl, goal:str, sub_goal:str, tool_call_summaries: list[str], user_defined_prompts: dict={}):
template = PROMPT_JINJA_ENV.from_string(RANK_MEMORY)
system_prompt = template.render(goal=goal, sub_goal=sub_goal, results=[{"i": i, "content": s} for i,s in enumerate(tool_call_summaries)])
user_prompt = " → rank: "