mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 12:32:30 +08:00
Fix: escape multi-steps issues. (#11016)
### What problem does this PR solve? ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
@ -281,12 +281,21 @@ class Canvas(Graph):
|
||||
def _run_batch(f, t):
|
||||
with ThreadPoolExecutor(max_workers=5) as executor:
|
||||
thr = []
|
||||
for i in range(f, t):
|
||||
i = f
|
||||
while i < t:
|
||||
cpn = self.get_component_obj(self.path[i])
|
||||
if cpn.component_name.lower() in ["begin", "userfillup"]:
|
||||
thr.append(executor.submit(cpn.invoke, inputs=kwargs.get("inputs", {})))
|
||||
i += 1
|
||||
else:
|
||||
thr.append(executor.submit(cpn.invoke, **cpn.get_input()))
|
||||
for _, ele in cpn.get_input_elements().items():
|
||||
if isinstance(ele, dict) and ele.get("_cpn_id") and ele.get("_cpn_id") not in self.path[:i]:
|
||||
self.path.pop(i)
|
||||
t -= 1
|
||||
break
|
||||
else:
|
||||
thr.append(executor.submit(cpn.invoke, **cpn.get_input()))
|
||||
i += 1
|
||||
for t in thr:
|
||||
t.result()
|
||||
|
||||
@ -316,6 +325,7 @@ class Canvas(Graph):
|
||||
"thoughts": self.get_component_thoughts(self.path[i])
|
||||
})
|
||||
_run_batch(idx, to)
|
||||
to = len(self.path)
|
||||
# post processing of components invocation
|
||||
for i in range(idx, to):
|
||||
cpn = self.get_component(self.path[i])
|
||||
|
||||
@ -16,6 +16,13 @@
|
||||
from abc import ABC
|
||||
from agent.component.base import ComponentBase, ComponentParamBase
|
||||
|
||||
"""
|
||||
class VariableModel(BaseModel):
|
||||
data_type: Annotated[Literal["string", "number", "Object", "Boolean", "Array<string>", "Array<number>", "Array<object>", "Array<boolean>"], Field(default="Array<string>")]
|
||||
input_mode: Annotated[Literal["constant", "variable"], Field(default="constant")]
|
||||
value: Annotated[Any, Field(default=None)]
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
"""
|
||||
|
||||
class IterationParam(ComponentParamBase):
|
||||
"""
|
||||
|
||||
@ -216,7 +216,7 @@ class LLM(ComponentBase):
|
||||
error: str = ""
|
||||
output_structure=None
|
||||
try:
|
||||
output_structure=self._param.outputs['structured']
|
||||
output_structure = None#self._param.outputs['structured']
|
||||
except Exception:
|
||||
pass
|
||||
if output_structure:
|
||||
|
||||
@ -49,6 +49,9 @@ class MessageParam(ComponentParamBase):
|
||||
class Message(ComponentBase):
|
||||
component_name = "Message"
|
||||
|
||||
def get_input_elements(self) -> dict[str, Any]:
|
||||
return self.get_input_elements_from_text("".join(self._param.content))
|
||||
|
||||
def get_kwargs(self, script:str, kwargs:dict = {}, delimiter:str=None) -> tuple[str, dict[str, str | list | Any]]:
|
||||
for k,v in self.get_input_elements_from_text(script).items():
|
||||
if k in kwargs:
|
||||
|
||||
@ -16,6 +16,8 @@
|
||||
import os
|
||||
import re
|
||||
from abc import ABC
|
||||
from typing import Any
|
||||
|
||||
from jinja2 import Template as Jinja2Template
|
||||
from agent.component.base import ComponentParamBase
|
||||
from common.connection_utils import timeout
|
||||
@ -43,6 +45,9 @@ class StringTransformParam(ComponentParamBase):
|
||||
class StringTransform(Message, ABC):
|
||||
component_name = "StringTransform"
|
||||
|
||||
def get_input_elements(self) -> dict[str, Any]:
|
||||
return self.get_input_elements_from_text(self._param.script)
|
||||
|
||||
def get_input_form(self) -> dict[str, dict]:
|
||||
if self._param.method == "split":
|
||||
return {
|
||||
|
||||
@ -111,12 +111,14 @@ class SyncLogsService(CommonService):
|
||||
return list(query.dicts())
|
||||
|
||||
@classmethod
|
||||
def start(cls, id):
|
||||
def start(cls, id, connector_id):
|
||||
cls.update_by_id(id, {"status": TaskStatus.RUNNING, "time_started": datetime.now().strftime('%Y-%m-%d %H:%M:%S') })
|
||||
ConnectorService.update_by_id(connector_id, {"status": TaskStatus.RUNNING})
|
||||
|
||||
@classmethod
|
||||
def done(cls, id):
|
||||
def done(cls, id, connector_id):
|
||||
cls.update_by_id(id, {"status": TaskStatus.DONE})
|
||||
ConnectorService.update_by_id(connector_id, {"status": TaskStatus.DONE})
|
||||
|
||||
@classmethod
|
||||
def schedule(cls, connector_id, kb_id, poll_range_start=None, reindex=False, total_docs_indexed=0):
|
||||
@ -126,6 +128,7 @@ class SyncLogsService(CommonService):
|
||||
logging.warning(f"{kb_id}--{connector_id} has already had a scheduling sync task which is abnormal.")
|
||||
return None
|
||||
reindex = "1" if reindex else "0"
|
||||
ConnectorService.update_by_id(connector_id, {"status": TaskStatus.SCHEDUL})
|
||||
return cls.save(**{
|
||||
"id": get_uuid(),
|
||||
"kb_id": kb_id, "status": TaskStatus.SCHEDULE, "connector_id": connector_id,
|
||||
@ -142,6 +145,7 @@ class SyncLogsService(CommonService):
|
||||
full_exception_trace=cls.model.full_exception_trace + str(e)
|
||||
) \
|
||||
.where(cls.model.id == task.id).execute()
|
||||
ConnectorService.update_by_id(connector_id, {"status": TaskStatus.SCHEDUL})
|
||||
|
||||
@classmethod
|
||||
def increase_docs(cls, id, min_update, max_update, doc_num, err_msg="", error_count=0):
|
||||
|
||||
@ -51,7 +51,7 @@ class SyncBase:
|
||||
self.conf = conf
|
||||
|
||||
async def __call__(self, task: dict):
|
||||
SyncLogsService.start(task["id"])
|
||||
SyncLogsService.start(task["id"], task["connector_id"])
|
||||
try:
|
||||
async with task_limiter:
|
||||
with trio.fail_after(task["timeout_secs"]):
|
||||
@ -113,7 +113,7 @@ class S3(SyncBase):
|
||||
self.conf["bucket_name"],
|
||||
begin_info
|
||||
))
|
||||
SyncLogsService.done(task["id"])
|
||||
SyncLogsService.done(task["id"], task["connector_id"])
|
||||
return next_update
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user