diff --git a/agent/canvas.py b/agent/canvas.py
index a1c31c17a..f0691145d 100644
--- a/agent/canvas.py
+++ b/agent/canvas.py
@@ -161,7 +161,32 @@ class Graph:
         cpn = self.get_component(cpn_id)
         if not cpn:
             raise Exception(f"Can't find variable: '{cpn_id}@{var_nm}'")
-        return cpn["obj"].output(var_nm)
+        parts = var_nm.split(".", 1)
+        root_key = parts[0]
+        rest = parts[1] if len(parts) > 1 else ""
+        root_val = cpn["obj"].output(root_key)
+
+        if not rest:
+            return root_val
+        return self.get_variable_param_value(root_val, rest)
+
+    def get_variable_param_value(self, obj: Any, path: str) -> Any:
+        cur = obj
+        if not path:
+            return cur
+        for key in path.split('.'):
+            if cur is None:
+                return None
+            if isinstance(cur, str):
+                try:
+                    cur = json.loads(cur)
+                except Exception:
+                    return None
+            if isinstance(cur, dict):
+                cur = cur.get(key)
+            else:
+                cur = getattr(cur, key, None)
+        return cur
 
 
 class Canvas(Graph):
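
Note: the new get_variable_param_value walk makes references such as component_id@structured.user.name resolvable, descending through dicts, JSON strings, and object attributes one dotted segment at a time. A minimal standalone sketch of that traversal; the resolve_path name and the sample payload are illustrative, not part of the patch:

import json
from typing import Any

def resolve_path(obj: Any, path: str) -> Any:
    # Mirrors Graph.get_variable_param_value: descend one dotted segment at a time.
    cur = obj
    for key in path.split("."):
        if cur is None:
            return None
        if isinstance(cur, str):
            # Outputs may arrive as JSON strings; decode before descending further.
            try:
                cur = json.loads(cur)
            except Exception:
                return None
        if isinstance(cur, dict):
            cur = cur.get(key)
        else:
            cur = getattr(cur, key, None)
    return cur

# Hypothetical upstream output stored as a JSON string:
payload = '{"user": {"name": "Ada", "age": 36}}'
assert resolve_path(payload, "user.name") == "Ada"
assert resolve_path(payload, "user.missing") is None
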
diff --git a/agent/component/agent_with_tools.py b/agent/component/agent_with_tools.py
index 32458fc85..b3e2df6f5 100644
--- a/agent/component/agent_with_tools.py
+++ b/agent/component/agent_with_tools.py
@@ -158,7 +158,12 @@ class Agent(LLM, ToolBase):
 
         downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
         ex = self.exception_handler()
-        if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not self._param.output_structure and not (ex and ex["goto"]):
+        output_structure = None
+        try:
+            output_structure = self._param.outputs['structured']
+        except Exception:
+            pass
+        if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not output_structure and not (ex and ex["goto"]):
             self.set_output("content", partial(self.stream_output_with_tools, prompt, msg, user_defined_prompt))
             return
 
diff --git a/agent/component/llm.py b/agent/component/llm.py
index 1e6c35c27..61e52b6fa 100644
--- a/agent/component/llm.py
+++ b/agent/component/llm.py
@@ -26,7 +26,7 @@ from api.db.services.llm_service import LLMBundle
 from api.db.services.tenant_llm_service import TenantLLMService
 from agent.component.base import ComponentBase, ComponentParamBase
 from api.utils.api_utils import timeout
-from rag.prompts.generator import tool_call_summary, message_fit_in, citation_prompt
+from rag.prompts.generator import tool_call_summary, message_fit_in, citation_prompt, structured_output_prompt
 
 
 class LLMParam(ComponentParamBase):
@@ -214,10 +214,14 @@ class LLM(ComponentBase):
         prompt, msg, _ = self._prepare_prompt_variables()
         error: str = ""
-
-        if self._param.output_structure:
-            prompt += "\nThe output MUST follow this JSON format:\n"+json.dumps(self._param.output_structure, ensure_ascii=False, indent=2)
-            prompt += "\nRedundant information is FORBIDDEN."
+        output_structure = None
+        try:
+            output_structure = self._param.outputs['structured']
+        except Exception:
+            pass
+        if output_structure:
+            schema = json.dumps(output_structure, ensure_ascii=False, indent=2)
+            prompt += structured_output_prompt(schema)
         for _ in range(self._param.max_retries+1):
             _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
             error = ""
@@ -228,7 +232,7 @@ class LLM(ComponentBase):
                 error = ans
                 continue
             try:
-                self.set_output("structured_content", json_repair.loads(clean_formated_answer(ans)))
+                self.set_output("structured", json_repair.loads(clean_formated_answer(ans)))
                 return
             except Exception:
                 msg.append({"role": "user", "content": "The answer can't not be parsed as JSON"})
@@ -239,7 +243,7 @@ class LLM(ComponentBase):
 
         downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
         ex = self.exception_handler()
-        if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not self._param.output_structure and not (ex and ex["goto"]):
+        if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not output_structure and not (ex and ex["goto"]):
             self.set_output("content", partial(self._stream_output, prompt, msg))
             return
 
diff --git a/rag/prompts/generator.py b/rag/prompts/generator.py
index 78f6cbad9..8cb69dcef 100644
--- a/rag/prompts/generator.py
+++ b/rag/prompts/generator.py
@@ -146,6 +146,7 @@ KEYWORD_PROMPT_TEMPLATE = load_prompt("keyword_prompt")
 QUESTION_PROMPT_TEMPLATE = load_prompt("question_prompt")
 VISION_LLM_DESCRIBE_PROMPT = load_prompt("vision_llm_describe_prompt")
 VISION_LLM_FIGURE_DESCRIBE_PROMPT = load_prompt("vision_llm_figure_describe_prompt")
+STRUCTURED_OUTPUT_PROMPT = load_prompt("structured_output_prompt")
 ANALYZE_TASK_SYSTEM = load_prompt("analyze_task_system")
 ANALYZE_TASK_USER = load_prompt("analyze_task_user")
 
@@ -403,6 +404,11 @@ def form_message(system_prompt, user_prompt):
     return [{"role": "system", "content": system_prompt},{"role": "user", "content": user_prompt}]
 
 
+def structured_output_prompt(schema=None) -> str:
+    template = PROMPT_JINJA_ENV.from_string(STRUCTURED_OUTPUT_PROMPT)
+    return template.render(schema=schema)
+
+
 def tool_call_summary(chat_mdl, name: str, params: dict, result: str, user_defined_prompts: dict={}) -> str:
     template = PROMPT_JINJA_ENV.from_string(SUMMARY4MEMORY)
     system_prompt = template.render(name=name,
diff --git a/rag/prompts/structured_output_prompt.md b/rag/prompts/structured_output_prompt.md
new file mode 100644
index 000000000..a64301115
--- /dev/null
+++ b/rag/prompts/structured_output_prompt.md
@@ -0,0 +1,16 @@
+You are a helpful AI assistant that answers questions and outputs in JSON format.
+Constraints:
+  - You must output in JSON format.
+  - Do not output boolean values; use the string type instead.
+  - Do not output integer or float values; use the number type instead.
+Example:
+  Here is the JSON schema:
+  {"properties": {"age": {"type": "number", "description": ""}, "name": {"type": "string", "description": ""}}, "required": ["age", "name"], "type": "Object Array String Number Boolean", "value": ""}
+
+  Here is the user's question:
+  My name is John Doe and I am 30 years old.
+
+  Output:
+  {"name": "John Doe", "age": 30}
+Here is the JSON schema:
+  {{ schema }}
\ No newline at end of file
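
For readers unfamiliar with the prompt plumbing above: structured_output_prompt renders structured_output_prompt.md through the shared Jinja environment, substituting the serialized schema for {{ schema }} at the end of the template. A self-contained approximation, with a plain jinja2.Environment and an inline template string standing in for PROMPT_JINJA_ENV and the loaded .md file:

import json
from jinja2 import Environment

# Stand-ins for PROMPT_JINJA_ENV and the loaded structured_output_prompt.md.
env = Environment(autoescape=False)
TEMPLATE = "Here is the JSON schema:\n  {{ schema }}"

def structured_output_prompt(schema=None) -> str:
    # Same shape as the new helper in rag/prompts/generator.py.
    return env.from_string(TEMPLATE).render(schema=schema)

schema = json.dumps({"properties": {"name": {"type": "string"}}}, indent=2)
print(structured_output_prompt(schema))
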
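On the consuming side, the retry loop in llm.py passes each answer through clean_formated_answer and json_repair.loads before publishing it under the renamed "structured" output key. A rough sketch of that repair step, assuming the LLM wrapped its answer in a code fence; strip_fences is a hypothetical stand-in for clean_formated_answer, whose exact behavior is not shown in this patch:

import json_repair

def strip_fences(ans: str) -> str:
    # Hypothetical stand-in for clean_formated_answer: drop ```json fences.
    ans = ans.strip()
    if ans.startswith("```"):
        ans = ans.strip("`")
        if ans.startswith("json"):
            ans = ans[len("json"):]
    return ans.strip()

# json_repair tolerates the malformed JSON LLMs often emit (e.g. trailing commas).
raw = '```json\n{"name": "John Doe", "age": 30,}\n```'
print(json_repair.loads(strip_fences(raw)))  # {'name': 'John Doe', 'age': 30}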