Fix: escape multi-steps issues. (#11016)

### What problem does this PR solve?


### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
Kevin Hu
2025-11-05 14:51:00 +08:00
committed by GitHub
parent 1a9215bc6f
commit b86e07088b
7 changed files with 36 additions and 7 deletions

View File

@ -281,12 +281,21 @@ class Canvas(Graph):
def _run_batch(f, t): def _run_batch(f, t):
with ThreadPoolExecutor(max_workers=5) as executor: with ThreadPoolExecutor(max_workers=5) as executor:
thr = [] thr = []
for i in range(f, t): i = f
while i < t:
cpn = self.get_component_obj(self.path[i]) cpn = self.get_component_obj(self.path[i])
if cpn.component_name.lower() in ["begin", "userfillup"]: if cpn.component_name.lower() in ["begin", "userfillup"]:
thr.append(executor.submit(cpn.invoke, inputs=kwargs.get("inputs", {}))) thr.append(executor.submit(cpn.invoke, inputs=kwargs.get("inputs", {})))
i += 1
else: else:
thr.append(executor.submit(cpn.invoke, **cpn.get_input())) for _, ele in cpn.get_input_elements().items():
if isinstance(ele, dict) and ele.get("_cpn_id") and ele.get("_cpn_id") not in self.path[:i]:
self.path.pop(i)
t -= 1
break
else:
thr.append(executor.submit(cpn.invoke, **cpn.get_input()))
i += 1
for t in thr: for t in thr:
t.result() t.result()
@ -316,6 +325,7 @@ class Canvas(Graph):
"thoughts": self.get_component_thoughts(self.path[i]) "thoughts": self.get_component_thoughts(self.path[i])
}) })
_run_batch(idx, to) _run_batch(idx, to)
to = len(self.path)
# post processing of components invocation # post processing of components invocation
for i in range(idx, to): for i in range(idx, to):
cpn = self.get_component(self.path[i]) cpn = self.get_component(self.path[i])

View File

@ -16,6 +16,13 @@
from abc import ABC from abc import ABC
from agent.component.base import ComponentBase, ComponentParamBase from agent.component.base import ComponentBase, ComponentParamBase
"""
class VariableModel(BaseModel):
data_type: Annotated[Literal["string", "number", "Object", "Boolean", "Array<string>", "Array<number>", "Array<object>", "Array<boolean>"], Field(default="Array<string>")]
input_mode: Annotated[Literal["constant", "variable"], Field(default="constant")]
value: Annotated[Any, Field(default=None)]
model_config = ConfigDict(extra="forbid")
"""
class IterationParam(ComponentParamBase): class IterationParam(ComponentParamBase):
""" """

View File

@ -216,7 +216,7 @@ class LLM(ComponentBase):
error: str = "" error: str = ""
output_structure=None output_structure=None
try: try:
output_structure=self._param.outputs['structured'] output_structure = None#self._param.outputs['structured']
except Exception: except Exception:
pass pass
if output_structure: if output_structure:

View File

@ -49,6 +49,9 @@ class MessageParam(ComponentParamBase):
class Message(ComponentBase): class Message(ComponentBase):
component_name = "Message" component_name = "Message"
def get_input_elements(self) -> dict[str, Any]:
return self.get_input_elements_from_text("".join(self._param.content))
def get_kwargs(self, script:str, kwargs:dict = {}, delimiter:str=None) -> tuple[str, dict[str, str | list | Any]]: def get_kwargs(self, script:str, kwargs:dict = {}, delimiter:str=None) -> tuple[str, dict[str, str | list | Any]]:
for k,v in self.get_input_elements_from_text(script).items(): for k,v in self.get_input_elements_from_text(script).items():
if k in kwargs: if k in kwargs:

View File

@ -16,6 +16,8 @@
import os import os
import re import re
from abc import ABC from abc import ABC
from typing import Any
from jinja2 import Template as Jinja2Template from jinja2 import Template as Jinja2Template
from agent.component.base import ComponentParamBase from agent.component.base import ComponentParamBase
from common.connection_utils import timeout from common.connection_utils import timeout
@ -43,6 +45,9 @@ class StringTransformParam(ComponentParamBase):
class StringTransform(Message, ABC): class StringTransform(Message, ABC):
component_name = "StringTransform" component_name = "StringTransform"
def get_input_elements(self) -> dict[str, Any]:
return self.get_input_elements_from_text(self._param.script)
def get_input_form(self) -> dict[str, dict]: def get_input_form(self) -> dict[str, dict]:
if self._param.method == "split": if self._param.method == "split":
return { return {

View File

@ -111,12 +111,14 @@ class SyncLogsService(CommonService):
return list(query.dicts()) return list(query.dicts())
@classmethod @classmethod
def start(cls, id): def start(cls, id, connector_id):
cls.update_by_id(id, {"status": TaskStatus.RUNNING, "time_started": datetime.now().strftime('%Y-%m-%d %H:%M:%S') }) cls.update_by_id(id, {"status": TaskStatus.RUNNING, "time_started": datetime.now().strftime('%Y-%m-%d %H:%M:%S') })
ConnectorService.update_by_id(connector_id, {"status": TaskStatus.RUNNING})
@classmethod @classmethod
def done(cls, id): def done(cls, id, connector_id):
cls.update_by_id(id, {"status": TaskStatus.DONE}) cls.update_by_id(id, {"status": TaskStatus.DONE})
ConnectorService.update_by_id(connector_id, {"status": TaskStatus.DONE})
@classmethod @classmethod
def schedule(cls, connector_id, kb_id, poll_range_start=None, reindex=False, total_docs_indexed=0): def schedule(cls, connector_id, kb_id, poll_range_start=None, reindex=False, total_docs_indexed=0):
@ -126,6 +128,7 @@ class SyncLogsService(CommonService):
logging.warning(f"{kb_id}--{connector_id} has already had a scheduling sync task which is abnormal.") logging.warning(f"{kb_id}--{connector_id} has already had a scheduling sync task which is abnormal.")
return None return None
reindex = "1" if reindex else "0" reindex = "1" if reindex else "0"
ConnectorService.update_by_id(connector_id, {"status": TaskStatus.SCHEDUL})
return cls.save(**{ return cls.save(**{
"id": get_uuid(), "id": get_uuid(),
"kb_id": kb_id, "status": TaskStatus.SCHEDULE, "connector_id": connector_id, "kb_id": kb_id, "status": TaskStatus.SCHEDULE, "connector_id": connector_id,
@ -142,6 +145,7 @@ class SyncLogsService(CommonService):
full_exception_trace=cls.model.full_exception_trace + str(e) full_exception_trace=cls.model.full_exception_trace + str(e)
) \ ) \
.where(cls.model.id == task.id).execute() .where(cls.model.id == task.id).execute()
ConnectorService.update_by_id(connector_id, {"status": TaskStatus.SCHEDUL})
@classmethod @classmethod
def increase_docs(cls, id, min_update, max_update, doc_num, err_msg="", error_count=0): def increase_docs(cls, id, min_update, max_update, doc_num, err_msg="", error_count=0):

View File

@ -51,7 +51,7 @@ class SyncBase:
self.conf = conf self.conf = conf
async def __call__(self, task: dict): async def __call__(self, task: dict):
SyncLogsService.start(task["id"]) SyncLogsService.start(task["id"], task["connector_id"])
try: try:
async with task_limiter: async with task_limiter:
with trio.fail_after(task["timeout_secs"]): with trio.fail_after(task["timeout_secs"]):
@ -113,7 +113,7 @@ class S3(SyncBase):
self.conf["bucket_name"], self.conf["bucket_name"],
begin_info begin_info
)) ))
SyncLogsService.done(task["id"]) SyncLogsService.done(task["id"], task["connector_id"])
return next_update return next_update