From b86e07088b01f5c1eb0e5783a8ba9bacfc39b5fa Mon Sep 17 00:00:00 2001 From: Kevin Hu Date: Wed, 5 Nov 2025 14:51:00 +0800 Subject: [PATCH] Fix: escape multi-steps issues. (#11016) ### What problem does this PR solve? ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- agent/canvas.py | 14 ++++++++++++-- agent/component/iteration.py | 7 +++++++ agent/component/llm.py | 2 +- agent/component/message.py | 3 +++ agent/component/string_transform.py | 5 +++++ api/db/services/connector_service.py | 8 ++++++-- rag/svr/sync_data_source.py | 4 ++-- 7 files changed, 36 insertions(+), 7 deletions(-) diff --git a/agent/canvas.py b/agent/canvas.py index f0691145d..348717bda 100644 --- a/agent/canvas.py +++ b/agent/canvas.py @@ -281,12 +281,21 @@ class Canvas(Graph): def _run_batch(f, t): with ThreadPoolExecutor(max_workers=5) as executor: thr = [] - for i in range(f, t): + i = f + while i < t: cpn = self.get_component_obj(self.path[i]) if cpn.component_name.lower() in ["begin", "userfillup"]: thr.append(executor.submit(cpn.invoke, inputs=kwargs.get("inputs", {}))) + i += 1 else: - thr.append(executor.submit(cpn.invoke, **cpn.get_input())) + for _, ele in cpn.get_input_elements().items(): + if isinstance(ele, dict) and ele.get("_cpn_id") and ele.get("_cpn_id") not in self.path[:i]: + self.path.pop(i) + t -= 1 + break + else: + thr.append(executor.submit(cpn.invoke, **cpn.get_input())) + i += 1 for t in thr: t.result() @@ -316,6 +325,7 @@ class Canvas(Graph): "thoughts": self.get_component_thoughts(self.path[i]) }) _run_batch(idx, to) + to = len(self.path) # post processing of components invocation for i in range(idx, to): cpn = self.get_component(self.path[i]) diff --git a/agent/component/iteration.py b/agent/component/iteration.py index 460969d7e..a6065a281 100644 --- a/agent/component/iteration.py +++ b/agent/component/iteration.py @@ -16,6 +16,13 @@ from abc import ABC from agent.component.base import ComponentBase, ComponentParamBase +""" +class VariableModel(BaseModel): + data_type: Annotated[Literal["string", "number", "Object", "Boolean", "Array", "Array", "Array", "Array"], Field(default="Array")] + input_mode: Annotated[Literal["constant", "variable"], Field(default="constant")] + value: Annotated[Any, Field(default=None)] + model_config = ConfigDict(extra="forbid") +""" class IterationParam(ComponentParamBase): """ diff --git a/agent/component/llm.py b/agent/component/llm.py index f8d9c3cc4..d2ed1514d 100644 --- a/agent/component/llm.py +++ b/agent/component/llm.py @@ -216,7 +216,7 @@ class LLM(ComponentBase): error: str = "" output_structure=None try: - output_structure=self._param.outputs['structured'] + output_structure = None#self._param.outputs['structured'] except Exception: pass if output_structure: diff --git a/agent/component/message.py b/agent/component/message.py index a91c2522e..6fcedf984 100644 --- a/agent/component/message.py +++ b/agent/component/message.py @@ -49,6 +49,9 @@ class MessageParam(ComponentParamBase): class Message(ComponentBase): component_name = "Message" + def get_input_elements(self) -> dict[str, Any]: + return self.get_input_elements_from_text("".join(self._param.content)) + def get_kwargs(self, script:str, kwargs:dict = {}, delimiter:str=None) -> tuple[str, dict[str, str | list | Any]]: for k,v in self.get_input_elements_from_text(script).items(): if k in kwargs: diff --git a/agent/component/string_transform.py b/agent/component/string_transform.py index 7802075d1..08e44d2e0 100644 --- a/agent/component/string_transform.py +++ b/agent/component/string_transform.py @@ -16,6 +16,8 @@ import os import re from abc import ABC +from typing import Any + from jinja2 import Template as Jinja2Template from agent.component.base import ComponentParamBase from common.connection_utils import timeout @@ -43,6 +45,9 @@ class StringTransformParam(ComponentParamBase): class StringTransform(Message, ABC): component_name = "StringTransform" + def get_input_elements(self) -> dict[str, Any]: + return self.get_input_elements_from_text(self._param.script) + def get_input_form(self) -> dict[str, dict]: if self._param.method == "split": return { diff --git a/api/db/services/connector_service.py b/api/db/services/connector_service.py index cadfc2b77..3ba00596a 100644 --- a/api/db/services/connector_service.py +++ b/api/db/services/connector_service.py @@ -111,12 +111,14 @@ class SyncLogsService(CommonService): return list(query.dicts()) @classmethod - def start(cls, id): + def start(cls, id, connector_id): cls.update_by_id(id, {"status": TaskStatus.RUNNING, "time_started": datetime.now().strftime('%Y-%m-%d %H:%M:%S') }) + ConnectorService.update_by_id(connector_id, {"status": TaskStatus.RUNNING}) @classmethod - def done(cls, id): + def done(cls, id, connector_id): cls.update_by_id(id, {"status": TaskStatus.DONE}) + ConnectorService.update_by_id(connector_id, {"status": TaskStatus.DONE}) @classmethod def schedule(cls, connector_id, kb_id, poll_range_start=None, reindex=False, total_docs_indexed=0): @@ -126,6 +128,7 @@ class SyncLogsService(CommonService): logging.warning(f"{kb_id}--{connector_id} has already had a scheduling sync task which is abnormal.") return None reindex = "1" if reindex else "0" + ConnectorService.update_by_id(connector_id, {"status": TaskStatus.SCHEDUL}) return cls.save(**{ "id": get_uuid(), "kb_id": kb_id, "status": TaskStatus.SCHEDULE, "connector_id": connector_id, @@ -142,6 +145,7 @@ class SyncLogsService(CommonService): full_exception_trace=cls.model.full_exception_trace + str(e) ) \ .where(cls.model.id == task.id).execute() + ConnectorService.update_by_id(connector_id, {"status": TaskStatus.SCHEDUL}) @classmethod def increase_docs(cls, id, min_update, max_update, doc_num, err_msg="", error_count=0): diff --git a/rag/svr/sync_data_source.py b/rag/svr/sync_data_source.py index b75dc5bf9..181b51286 100644 --- a/rag/svr/sync_data_source.py +++ b/rag/svr/sync_data_source.py @@ -51,7 +51,7 @@ class SyncBase: self.conf = conf async def __call__(self, task: dict): - SyncLogsService.start(task["id"]) + SyncLogsService.start(task["id"], task["connector_id"]) try: async with task_limiter: with trio.fail_after(task["timeout_secs"]): @@ -113,7 +113,7 @@ class S3(SyncBase): self.conf["bucket_name"], begin_info )) - SyncLogsService.done(task["id"]) + SyncLogsService.done(task["id"], task["connector_id"]) return next_update