mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Feat: user defined prompt. (#9972)
### What problem does this PR solve? ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -157,7 +157,7 @@ ASK_SUMMARY = load_prompt("ask_summary")
|
||||
PROMPT_JINJA_ENV = jinja2.Environment(autoescape=False, trim_blocks=True, lstrip_blocks=True)
|
||||
|
||||
|
||||
def citation_prompt() -> str:
|
||||
def citation_prompt(user_defined_prompts: dict={}) -> str:
|
||||
template = PROMPT_JINJA_ENV.from_string(CITATION_PROMPT_TEMPLATE)
|
||||
return template.render()
|
||||
|
||||
@ -339,7 +339,7 @@ def form_history(history, limit=-6):
|
||||
return context
|
||||
|
||||
|
||||
def analyze_task(chat_mdl, prompt, task_name, tools_description: list[dict]):
|
||||
def analyze_task(chat_mdl, prompt, task_name, tools_description: list[dict], user_defined_prompts: dict={}):
|
||||
tools_desc = tool_schema(tools_description)
|
||||
context = ""
|
||||
|
||||
@ -354,7 +354,7 @@ def analyze_task(chat_mdl, prompt, task_name, tools_description: list[dict]):
|
||||
return kwd
|
||||
|
||||
|
||||
def next_step(chat_mdl, history:list, tools_description: list[dict], task_desc):
|
||||
def next_step(chat_mdl, history:list, tools_description: list[dict], task_desc, user_defined_prompts: dict={}):
|
||||
if not tools_description:
|
||||
return ""
|
||||
desc = tool_schema(tools_description)
|
||||
@ -372,7 +372,7 @@ def next_step(chat_mdl, history:list, tools_description: list[dict], task_desc):
|
||||
return json_str, tk_cnt
|
||||
|
||||
|
||||
def reflect(chat_mdl, history: list[dict], tool_call_res: list[Tuple]):
|
||||
def reflect(chat_mdl, history: list[dict], tool_call_res: list[Tuple], user_defined_prompts: dict={}):
|
||||
tool_calls = [{"name": p[0], "result": p[1]} for p in tool_call_res]
|
||||
goal = history[1]["content"]
|
||||
template = PROMPT_JINJA_ENV.from_string(REFLECT)
|
||||
@ -398,7 +398,7 @@ def form_message(system_prompt, user_prompt):
|
||||
return [{"role": "system", "content": system_prompt},{"role": "user", "content": user_prompt}]
|
||||
|
||||
|
||||
def tool_call_summary(chat_mdl, name: str, params: dict, result: str) -> str:
|
||||
def tool_call_summary(chat_mdl, name: str, params: dict, result: str, user_defined_prompts: dict={}) -> str:
|
||||
template = PROMPT_JINJA_ENV.from_string(SUMMARY4MEMORY)
|
||||
system_prompt = template.render(name=name,
|
||||
params=json.dumps(params, ensure_ascii=False, indent=2),
|
||||
@ -409,7 +409,7 @@ def tool_call_summary(chat_mdl, name: str, params: dict, result: str) -> str:
|
||||
return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
|
||||
|
||||
|
||||
def rank_memories(chat_mdl, goal:str, sub_goal:str, tool_call_summaries: list[str]):
|
||||
def rank_memories(chat_mdl, goal:str, sub_goal:str, tool_call_summaries: list[str], user_defined_prompts: dict={}):
|
||||
template = PROMPT_JINJA_ENV.from_string(RANK_MEMORY)
|
||||
system_prompt = template.render(goal=goal, sub_goal=sub_goal, results=[{"i": i, "content": s} for i,s in enumerate(tool_call_summaries)])
|
||||
user_prompt = " → rank: "
|
||||
|
||||
Reference in New Issue
Block a user