mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 12:32:30 +08:00
Fix: escape multi-steps issues. (#11016)
### What problem does this PR solve? ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
@ -281,12 +281,21 @@ class Canvas(Graph):
|
|||||||
def _run_batch(f, t):
|
def _run_batch(f, t):
|
||||||
with ThreadPoolExecutor(max_workers=5) as executor:
|
with ThreadPoolExecutor(max_workers=5) as executor:
|
||||||
thr = []
|
thr = []
|
||||||
for i in range(f, t):
|
i = f
|
||||||
|
while i < t:
|
||||||
cpn = self.get_component_obj(self.path[i])
|
cpn = self.get_component_obj(self.path[i])
|
||||||
if cpn.component_name.lower() in ["begin", "userfillup"]:
|
if cpn.component_name.lower() in ["begin", "userfillup"]:
|
||||||
thr.append(executor.submit(cpn.invoke, inputs=kwargs.get("inputs", {})))
|
thr.append(executor.submit(cpn.invoke, inputs=kwargs.get("inputs", {})))
|
||||||
|
i += 1
|
||||||
else:
|
else:
|
||||||
thr.append(executor.submit(cpn.invoke, **cpn.get_input()))
|
for _, ele in cpn.get_input_elements().items():
|
||||||
|
if isinstance(ele, dict) and ele.get("_cpn_id") and ele.get("_cpn_id") not in self.path[:i]:
|
||||||
|
self.path.pop(i)
|
||||||
|
t -= 1
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
thr.append(executor.submit(cpn.invoke, **cpn.get_input()))
|
||||||
|
i += 1
|
||||||
for t in thr:
|
for t in thr:
|
||||||
t.result()
|
t.result()
|
||||||
|
|
||||||
@ -316,6 +325,7 @@ class Canvas(Graph):
|
|||||||
"thoughts": self.get_component_thoughts(self.path[i])
|
"thoughts": self.get_component_thoughts(self.path[i])
|
||||||
})
|
})
|
||||||
_run_batch(idx, to)
|
_run_batch(idx, to)
|
||||||
|
to = len(self.path)
|
||||||
# post processing of components invocation
|
# post processing of components invocation
|
||||||
for i in range(idx, to):
|
for i in range(idx, to):
|
||||||
cpn = self.get_component(self.path[i])
|
cpn = self.get_component(self.path[i])
|
||||||
|
|||||||
@ -16,6 +16,13 @@
|
|||||||
from abc import ABC
|
from abc import ABC
|
||||||
from agent.component.base import ComponentBase, ComponentParamBase
|
from agent.component.base import ComponentBase, ComponentParamBase
|
||||||
|
|
||||||
|
"""
|
||||||
|
class VariableModel(BaseModel):
|
||||||
|
data_type: Annotated[Literal["string", "number", "Object", "Boolean", "Array<string>", "Array<number>", "Array<object>", "Array<boolean>"], Field(default="Array<string>")]
|
||||||
|
input_mode: Annotated[Literal["constant", "variable"], Field(default="constant")]
|
||||||
|
value: Annotated[Any, Field(default=None)]
|
||||||
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
"""
|
||||||
|
|
||||||
class IterationParam(ComponentParamBase):
|
class IterationParam(ComponentParamBase):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -216,7 +216,7 @@ class LLM(ComponentBase):
|
|||||||
error: str = ""
|
error: str = ""
|
||||||
output_structure=None
|
output_structure=None
|
||||||
try:
|
try:
|
||||||
output_structure=self._param.outputs['structured']
|
output_structure = None#self._param.outputs['structured']
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
if output_structure:
|
if output_structure:
|
||||||
|
|||||||
@ -49,6 +49,9 @@ class MessageParam(ComponentParamBase):
|
|||||||
class Message(ComponentBase):
|
class Message(ComponentBase):
|
||||||
component_name = "Message"
|
component_name = "Message"
|
||||||
|
|
||||||
|
def get_input_elements(self) -> dict[str, Any]:
|
||||||
|
return self.get_input_elements_from_text("".join(self._param.content))
|
||||||
|
|
||||||
def get_kwargs(self, script:str, kwargs:dict = {}, delimiter:str=None) -> tuple[str, dict[str, str | list | Any]]:
|
def get_kwargs(self, script:str, kwargs:dict = {}, delimiter:str=None) -> tuple[str, dict[str, str | list | Any]]:
|
||||||
for k,v in self.get_input_elements_from_text(script).items():
|
for k,v in self.get_input_elements_from_text(script).items():
|
||||||
if k in kwargs:
|
if k in kwargs:
|
||||||
|
|||||||
@ -16,6 +16,8 @@
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from jinja2 import Template as Jinja2Template
|
from jinja2 import Template as Jinja2Template
|
||||||
from agent.component.base import ComponentParamBase
|
from agent.component.base import ComponentParamBase
|
||||||
from common.connection_utils import timeout
|
from common.connection_utils import timeout
|
||||||
@ -43,6 +45,9 @@ class StringTransformParam(ComponentParamBase):
|
|||||||
class StringTransform(Message, ABC):
|
class StringTransform(Message, ABC):
|
||||||
component_name = "StringTransform"
|
component_name = "StringTransform"
|
||||||
|
|
||||||
|
def get_input_elements(self) -> dict[str, Any]:
|
||||||
|
return self.get_input_elements_from_text(self._param.script)
|
||||||
|
|
||||||
def get_input_form(self) -> dict[str, dict]:
|
def get_input_form(self) -> dict[str, dict]:
|
||||||
if self._param.method == "split":
|
if self._param.method == "split":
|
||||||
return {
|
return {
|
||||||
|
|||||||
@ -111,12 +111,14 @@ class SyncLogsService(CommonService):
|
|||||||
return list(query.dicts())
|
return list(query.dicts())
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def start(cls, id):
|
def start(cls, id, connector_id):
|
||||||
cls.update_by_id(id, {"status": TaskStatus.RUNNING, "time_started": datetime.now().strftime('%Y-%m-%d %H:%M:%S') })
|
cls.update_by_id(id, {"status": TaskStatus.RUNNING, "time_started": datetime.now().strftime('%Y-%m-%d %H:%M:%S') })
|
||||||
|
ConnectorService.update_by_id(connector_id, {"status": TaskStatus.RUNNING})
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def done(cls, id):
|
def done(cls, id, connector_id):
|
||||||
cls.update_by_id(id, {"status": TaskStatus.DONE})
|
cls.update_by_id(id, {"status": TaskStatus.DONE})
|
||||||
|
ConnectorService.update_by_id(connector_id, {"status": TaskStatus.DONE})
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def schedule(cls, connector_id, kb_id, poll_range_start=None, reindex=False, total_docs_indexed=0):
|
def schedule(cls, connector_id, kb_id, poll_range_start=None, reindex=False, total_docs_indexed=0):
|
||||||
@ -126,6 +128,7 @@ class SyncLogsService(CommonService):
|
|||||||
logging.warning(f"{kb_id}--{connector_id} has already had a scheduling sync task which is abnormal.")
|
logging.warning(f"{kb_id}--{connector_id} has already had a scheduling sync task which is abnormal.")
|
||||||
return None
|
return None
|
||||||
reindex = "1" if reindex else "0"
|
reindex = "1" if reindex else "0"
|
||||||
|
ConnectorService.update_by_id(connector_id, {"status": TaskStatus.SCHEDUL})
|
||||||
return cls.save(**{
|
return cls.save(**{
|
||||||
"id": get_uuid(),
|
"id": get_uuid(),
|
||||||
"kb_id": kb_id, "status": TaskStatus.SCHEDULE, "connector_id": connector_id,
|
"kb_id": kb_id, "status": TaskStatus.SCHEDULE, "connector_id": connector_id,
|
||||||
@ -142,6 +145,7 @@ class SyncLogsService(CommonService):
|
|||||||
full_exception_trace=cls.model.full_exception_trace + str(e)
|
full_exception_trace=cls.model.full_exception_trace + str(e)
|
||||||
) \
|
) \
|
||||||
.where(cls.model.id == task.id).execute()
|
.where(cls.model.id == task.id).execute()
|
||||||
|
ConnectorService.update_by_id(connector_id, {"status": TaskStatus.SCHEDUL})
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def increase_docs(cls, id, min_update, max_update, doc_num, err_msg="", error_count=0):
|
def increase_docs(cls, id, min_update, max_update, doc_num, err_msg="", error_count=0):
|
||||||
|
|||||||
@ -51,7 +51,7 @@ class SyncBase:
|
|||||||
self.conf = conf
|
self.conf = conf
|
||||||
|
|
||||||
async def __call__(self, task: dict):
|
async def __call__(self, task: dict):
|
||||||
SyncLogsService.start(task["id"])
|
SyncLogsService.start(task["id"], task["connector_id"])
|
||||||
try:
|
try:
|
||||||
async with task_limiter:
|
async with task_limiter:
|
||||||
with trio.fail_after(task["timeout_secs"]):
|
with trio.fail_after(task["timeout_secs"]):
|
||||||
@ -113,7 +113,7 @@ class S3(SyncBase):
|
|||||||
self.conf["bucket_name"],
|
self.conf["bucket_name"],
|
||||||
begin_info
|
begin_info
|
||||||
))
|
))
|
||||||
SyncLogsService.done(task["id"])
|
SyncLogsService.done(task["id"], task["connector_id"])
|
||||||
return next_update
|
return next_update
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user