Fix: escape multi-steps issues. (#11016)

### What problem does this PR solve?


### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
Kevin Hu
2025-11-05 14:51:00 +08:00
committed by GitHub
parent 1a9215bc6f
commit b86e07088b
7 changed files with 36 additions and 7 deletions

View File

@ -281,12 +281,21 @@ class Canvas(Graph):
def _run_batch(f, t):
with ThreadPoolExecutor(max_workers=5) as executor:
thr = []
for i in range(f, t):
i = f
while i < t:
cpn = self.get_component_obj(self.path[i])
if cpn.component_name.lower() in ["begin", "userfillup"]:
thr.append(executor.submit(cpn.invoke, inputs=kwargs.get("inputs", {})))
i += 1
else:
thr.append(executor.submit(cpn.invoke, **cpn.get_input()))
for _, ele in cpn.get_input_elements().items():
if isinstance(ele, dict) and ele.get("_cpn_id") and ele.get("_cpn_id") not in self.path[:i]:
self.path.pop(i)
t -= 1
break
else:
thr.append(executor.submit(cpn.invoke, **cpn.get_input()))
i += 1
for t in thr:
t.result()
@ -316,6 +325,7 @@ class Canvas(Graph):
"thoughts": self.get_component_thoughts(self.path[i])
})
_run_batch(idx, to)
to = len(self.path)
# post processing of components invocation
for i in range(idx, to):
cpn = self.get_component(self.path[i])

View File

@ -16,6 +16,13 @@
from abc import ABC
from agent.component.base import ComponentBase, ComponentParamBase
"""
class VariableModel(BaseModel):
data_type: Annotated[Literal["string", "number", "Object", "Boolean", "Array<string>", "Array<number>", "Array<object>", "Array<boolean>"], Field(default="Array<string>")]
input_mode: Annotated[Literal["constant", "variable"], Field(default="constant")]
value: Annotated[Any, Field(default=None)]
model_config = ConfigDict(extra="forbid")
"""
class IterationParam(ComponentParamBase):
"""

View File

@ -216,7 +216,7 @@ class LLM(ComponentBase):
error: str = ""
output_structure=None
try:
output_structure=self._param.outputs['structured']
output_structure = None#self._param.outputs['structured']
except Exception:
pass
if output_structure:

View File

@ -49,6 +49,9 @@ class MessageParam(ComponentParamBase):
class Message(ComponentBase):
component_name = "Message"
def get_input_elements(self) -> dict[str, Any]:
return self.get_input_elements_from_text("".join(self._param.content))
def get_kwargs(self, script:str, kwargs:dict = {}, delimiter:str=None) -> tuple[str, dict[str, str | list | Any]]:
for k,v in self.get_input_elements_from_text(script).items():
if k in kwargs:

View File

@ -16,6 +16,8 @@
import os
import re
from abc import ABC
from typing import Any
from jinja2 import Template as Jinja2Template
from agent.component.base import ComponentParamBase
from common.connection_utils import timeout
@ -43,6 +45,9 @@ class StringTransformParam(ComponentParamBase):
class StringTransform(Message, ABC):
component_name = "StringTransform"
def get_input_elements(self) -> dict[str, Any]:
return self.get_input_elements_from_text(self._param.script)
def get_input_form(self) -> dict[str, dict]:
if self._param.method == "split":
return {

View File

@ -111,12 +111,14 @@ class SyncLogsService(CommonService):
return list(query.dicts())
@classmethod
def start(cls, id):
def start(cls, id, connector_id):
cls.update_by_id(id, {"status": TaskStatus.RUNNING, "time_started": datetime.now().strftime('%Y-%m-%d %H:%M:%S') })
ConnectorService.update_by_id(connector_id, {"status": TaskStatus.RUNNING})
@classmethod
def done(cls, id):
def done(cls, id, connector_id):
cls.update_by_id(id, {"status": TaskStatus.DONE})
ConnectorService.update_by_id(connector_id, {"status": TaskStatus.DONE})
@classmethod
def schedule(cls, connector_id, kb_id, poll_range_start=None, reindex=False, total_docs_indexed=0):
@ -126,6 +128,7 @@ class SyncLogsService(CommonService):
logging.warning(f"{kb_id}--{connector_id} has already had a scheduling sync task which is abnormal.")
return None
reindex = "1" if reindex else "0"
ConnectorService.update_by_id(connector_id, {"status": TaskStatus.SCHEDUL})
return cls.save(**{
"id": get_uuid(),
"kb_id": kb_id, "status": TaskStatus.SCHEDULE, "connector_id": connector_id,
@ -142,6 +145,7 @@ class SyncLogsService(CommonService):
full_exception_trace=cls.model.full_exception_trace + str(e)
) \
.where(cls.model.id == task.id).execute()
ConnectorService.update_by_id(connector_id, {"status": TaskStatus.SCHEDUL})
@classmethod
def increase_docs(cls, id, min_update, max_update, doc_num, err_msg="", error_count=0):

View File

@ -51,7 +51,7 @@ class SyncBase:
self.conf = conf
async def __call__(self, task: dict):
SyncLogsService.start(task["id"])
SyncLogsService.start(task["id"], task["connector_id"])
try:
async with task_limiter:
with trio.fail_after(task["timeout_secs"]):
@ -113,7 +113,7 @@ class S3(SyncBase):
self.conf["bucket_name"],
begin_info
))
SyncLogsService.done(task["id"])
SyncLogsService.done(task["id"], task["connector_id"])
return next_update