mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 12:32:30 +08:00
Fix:output_structure in agent (#10907)
### What problem does this PR solve? change: output_structure in agent ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
@ -161,7 +161,32 @@ class Graph:
|
|||||||
cpn = self.get_component(cpn_id)
|
cpn = self.get_component(cpn_id)
|
||||||
if not cpn:
|
if not cpn:
|
||||||
raise Exception(f"Can't find variable: '{cpn_id}@{var_nm}'")
|
raise Exception(f"Can't find variable: '{cpn_id}@{var_nm}'")
|
||||||
return cpn["obj"].output(var_nm)
|
parts = var_nm.split(".", 1)
|
||||||
|
root_key = parts[0]
|
||||||
|
rest = parts[1] if len(parts) > 1 else ""
|
||||||
|
root_val = cpn["obj"].output(root_key)
|
||||||
|
|
||||||
|
if not rest:
|
||||||
|
return root_val
|
||||||
|
return self.get_variable_param_value(root_val,rest)
|
||||||
|
|
||||||
|
def get_variable_param_value(self, obj: Any, path: str) -> Any:
|
||||||
|
cur = obj
|
||||||
|
if not path:
|
||||||
|
return cur
|
||||||
|
for key in path.split('.'):
|
||||||
|
if cur is None:
|
||||||
|
return None
|
||||||
|
if isinstance(cur, str):
|
||||||
|
try:
|
||||||
|
cur = json.loads(cur)
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
if isinstance(cur, dict):
|
||||||
|
cur = cur.get(key)
|
||||||
|
else:
|
||||||
|
cur = getattr(cur, key, None)
|
||||||
|
return cur
|
||||||
|
|
||||||
|
|
||||||
class Canvas(Graph):
|
class Canvas(Graph):
|
||||||
|
|||||||
@ -158,7 +158,12 @@ class Agent(LLM, ToolBase):
|
|||||||
|
|
||||||
downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
|
downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
|
||||||
ex = self.exception_handler()
|
ex = self.exception_handler()
|
||||||
if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not self._param.output_structure and not (ex and ex["goto"]):
|
output_structure=None
|
||||||
|
try:
|
||||||
|
output_structure=self._param.outputs['structured']
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not output_structure and not (ex and ex["goto"]):
|
||||||
self.set_output("content", partial(self.stream_output_with_tools, prompt, msg, user_defined_prompt))
|
self.set_output("content", partial(self.stream_output_with_tools, prompt, msg, user_defined_prompt))
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|||||||
@ -26,7 +26,7 @@ from api.db.services.llm_service import LLMBundle
|
|||||||
from api.db.services.tenant_llm_service import TenantLLMService
|
from api.db.services.tenant_llm_service import TenantLLMService
|
||||||
from agent.component.base import ComponentBase, ComponentParamBase
|
from agent.component.base import ComponentBase, ComponentParamBase
|
||||||
from api.utils.api_utils import timeout
|
from api.utils.api_utils import timeout
|
||||||
from rag.prompts.generator import tool_call_summary, message_fit_in, citation_prompt
|
from rag.prompts.generator import tool_call_summary, message_fit_in, citation_prompt, structured_output_prompt
|
||||||
|
|
||||||
|
|
||||||
class LLMParam(ComponentParamBase):
|
class LLMParam(ComponentParamBase):
|
||||||
@ -214,10 +214,14 @@ class LLM(ComponentBase):
|
|||||||
|
|
||||||
prompt, msg, _ = self._prepare_prompt_variables()
|
prompt, msg, _ = self._prepare_prompt_variables()
|
||||||
error: str = ""
|
error: str = ""
|
||||||
|
output_structure=None
|
||||||
if self._param.output_structure:
|
try:
|
||||||
prompt += "\nThe output MUST follow this JSON format:\n"+json.dumps(self._param.output_structure, ensure_ascii=False, indent=2)
|
output_structure=self._param.outputs['structured']
|
||||||
prompt += "\nRedundant information is FORBIDDEN."
|
except Exception:
|
||||||
|
pass
|
||||||
|
if output_structure:
|
||||||
|
schema=json.dumps(output_structure, ensure_ascii=False, indent=2)
|
||||||
|
prompt += structured_output_prompt(schema)
|
||||||
for _ in range(self._param.max_retries+1):
|
for _ in range(self._param.max_retries+1):
|
||||||
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
|
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
|
||||||
error = ""
|
error = ""
|
||||||
@ -228,7 +232,7 @@ class LLM(ComponentBase):
|
|||||||
error = ans
|
error = ans
|
||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
self.set_output("structured_content", json_repair.loads(clean_formated_answer(ans)))
|
self.set_output("structured", json_repair.loads(clean_formated_answer(ans)))
|
||||||
return
|
return
|
||||||
except Exception:
|
except Exception:
|
||||||
msg.append({"role": "user", "content": "The answer can't not be parsed as JSON"})
|
msg.append({"role": "user", "content": "The answer can't not be parsed as JSON"})
|
||||||
@ -239,7 +243,7 @@ class LLM(ComponentBase):
|
|||||||
|
|
||||||
downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
|
downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
|
||||||
ex = self.exception_handler()
|
ex = self.exception_handler()
|
||||||
if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not self._param.output_structure and not (ex and ex["goto"]):
|
if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not output_structure and not (ex and ex["goto"]):
|
||||||
self.set_output("content", partial(self._stream_output, prompt, msg))
|
self.set_output("content", partial(self._stream_output, prompt, msg))
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|||||||
@ -146,6 +146,7 @@ KEYWORD_PROMPT_TEMPLATE = load_prompt("keyword_prompt")
|
|||||||
QUESTION_PROMPT_TEMPLATE = load_prompt("question_prompt")
|
QUESTION_PROMPT_TEMPLATE = load_prompt("question_prompt")
|
||||||
VISION_LLM_DESCRIBE_PROMPT = load_prompt("vision_llm_describe_prompt")
|
VISION_LLM_DESCRIBE_PROMPT = load_prompt("vision_llm_describe_prompt")
|
||||||
VISION_LLM_FIGURE_DESCRIBE_PROMPT = load_prompt("vision_llm_figure_describe_prompt")
|
VISION_LLM_FIGURE_DESCRIBE_PROMPT = load_prompt("vision_llm_figure_describe_prompt")
|
||||||
|
STRUCTURED_OUTPUT_PROMPT = load_prompt("structured_output_prompt")
|
||||||
|
|
||||||
ANALYZE_TASK_SYSTEM = load_prompt("analyze_task_system")
|
ANALYZE_TASK_SYSTEM = load_prompt("analyze_task_system")
|
||||||
ANALYZE_TASK_USER = load_prompt("analyze_task_user")
|
ANALYZE_TASK_USER = load_prompt("analyze_task_user")
|
||||||
@ -403,6 +404,11 @@ def form_message(system_prompt, user_prompt):
|
|||||||
return [{"role": "system", "content": system_prompt},{"role": "user", "content": user_prompt}]
|
return [{"role": "system", "content": system_prompt},{"role": "user", "content": user_prompt}]
|
||||||
|
|
||||||
|
|
||||||
|
def structured_output_prompt(schema=None) -> str:
|
||||||
|
template = PROMPT_JINJA_ENV.from_string(STRUCTURED_OUTPUT_PROMPT)
|
||||||
|
return template.render(schema=schema)
|
||||||
|
|
||||||
|
|
||||||
def tool_call_summary(chat_mdl, name: str, params: dict, result: str, user_defined_prompts: dict={}) -> str:
|
def tool_call_summary(chat_mdl, name: str, params: dict, result: str, user_defined_prompts: dict={}) -> str:
|
||||||
template = PROMPT_JINJA_ENV.from_string(SUMMARY4MEMORY)
|
template = PROMPT_JINJA_ENV.from_string(SUMMARY4MEMORY)
|
||||||
system_prompt = template.render(name=name,
|
system_prompt = template.render(name=name,
|
||||||
|
|||||||
16
rag/prompts/structured_output_prompt.md
Normal file
16
rag/prompts/structured_output_prompt.md
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
You’re a helpful AI assistant. You could answer questions and output in JSON format.
|
||||||
|
constraints:
|
||||||
|
- You must output in JSON format.
|
||||||
|
- Do not output boolean value, use string type instead.
|
||||||
|
- Do not output integer or float value, use number type instead.
|
||||||
|
eg:
|
||||||
|
Here is the JSON schema:
|
||||||
|
{"properties": {"age": {"type": "number","description": ""},"name": {"type": "string","description": ""}},"required": ["age","name"],"type": "Object Array String Number Boolean","value": ""}
|
||||||
|
|
||||||
|
Here is the user's question:
|
||||||
|
My name is John Doe and I am 30 years old.
|
||||||
|
|
||||||
|
output:
|
||||||
|
{"name": "John Doe", "age": 30}
|
||||||
|
Here is the JSON schema:
|
||||||
|
{{ schema }}
|
||||||
Reference in New Issue
Block a user