Feat: add mechanism to check cancellation in Agent (#10766)

### What problem does this PR solve?

Add a mechanism to check for task cancellation in the Agent.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
Author: Yongteng Lei
Date: 2025-11-11 17:36:48 +08:00
Committed by: GitHub
Parent: d81e4095de
Commit: 9213568692
36 changed files with 495 additions and 20 deletions
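
The change is a cooperative-cancellation flag keyed by task id: `Graph.cancel_task` writes a `{task_id}-cancel` key to Redis, `has_canceled` reads it, the canvas raises `TaskCanceledException` when the flag is set, and every component polls `check_if_canceled` between steps so long-running work stops early. Below is a minimal, self-contained sketch of that pattern; it assumes a local `redis` client and simplified stand-ins for `REDIS_CONN`, `Canvas`, and `ComponentBase`, and is not the project's actual implementation (the diff below is).

```python
# Minimal sketch of the cancellation pattern (assumed names: a local `redis`
# client and a simplified Component stand in for REDIS_CONN / ComponentBase).
import logging

import redis


class TaskCanceledException(Exception):
    """Raised when a running workflow notices its cancel flag."""


r = redis.Redis()


def cancel_task(task_id: str) -> bool:
    # Writer side: mark the task as canceled by setting a Redis flag.
    try:
        r.set(f"{task_id}-cancel", "x")
        return True
    except Exception:
        logging.exception("failed to set cancel flag")
        return False


def has_canceled(task_id: str) -> bool:
    # Reader side: the canvas and its components poll this flag.
    try:
        return r.exists(f"{task_id}-cancel") > 0
    except Exception:
        logging.exception("failed to read cancel flag")
        return False


class Component:
    def __init__(self, task_id: str):
        self.task_id = task_id
        self.outputs: dict[str, str] = {}

    def check_if_canceled(self, message: str = "") -> bool:
        # Mirrors ComponentBase.check_if_canceled: log, record the error
        # output, and tell the caller to bail out.
        if has_canceled(self.task_id):
            logging.info("Task %s has been canceled during %s", self.task_id, message)
            self.outputs["_ERROR"] = "Task has been canceled"
            return True
        return False

    def run(self) -> None:
        for step in range(10):
            if self.check_if_canceled(f"step {step}"):
                return  # stop between units of work instead of finishing the loop
            # ... do one unit of work ...
```

Cancellation is cooperative: a component only stops at its next check, which is why the diff places checks before expensive calls and inside retry and streaming loops.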


@@ -26,7 +26,9 @@ from typing import Any, Union, Tuple
from agent.component import component_class
from agent.component.base import ComponentBase
from api.db.services.file_service import FileService
+from api.db.services.task_service import has_canceled
from common.misc_utils import get_uuid, hash_str2int
+from common.exceptions import TaskCanceledException
from rag.prompts.generator import chunks_format
from rag.utils.redis_conn import REDIS_CONN
@@ -126,6 +128,7 @@ class Graph:
self.components[k]["obj"].reset()
try:
REDIS_CONN.delete(f"{self.task_id}-logs")
+REDIS_CONN.delete(f"{self.task_id}-cancel")
except Exception as e:
logging.exception(e)
@@ -196,7 +199,7 @@ class Graph:
if not rest:
return root_val
return self.get_variable_param_value(root_val,rest)
def get_variable_param_value(self, obj: Any, path: str) -> Any:
cur = obj
if not path:
@@ -215,6 +218,17 @@ class Graph:
cur = getattr(cur, key, None)
return cur
+def is_canceled(self) -> bool:
+return has_canceled(self.task_id)
+def cancel_task(self) -> bool:
+try:
+REDIS_CONN.set(f"{self.task_id}-cancel", "x")
+except Exception as e:
+logging.exception(e)
+return False
+return True
class Canvas(Graph):
@@ -239,7 +253,7 @@ class Canvas(Graph):
"sys.conversation_turns": 0,
"sys.files": []
}
self.retrieval = self.dsl["retrieval"]
self.memory = self.dsl.get("memory", [])
@@ -311,10 +325,20 @@ class Canvas(Graph):
self.path.append("begin")
self.retrieval.append({"chunks": [], "doc_aggs": []})
+if self.is_canceled():
+msg = f"Task {self.task_id} has been canceled before starting."
+logging.info(msg)
+raise TaskCanceledException(msg)
yield decorate("workflow_started", {"inputs": kwargs.get("inputs")})
self.retrieval.append({"chunks": {}, "doc_aggs": {}})
def _run_batch(f, t):
+if self.is_canceled():
+msg = f"Task {self.task_id} has been canceled during batch execution."
+logging.info(msg)
+raise TaskCanceledException(msg)
with ThreadPoolExecutor(max_workers=5) as executor:
thr = []
i = f
@@ -473,6 +497,14 @@ class Canvas(Graph):
"created_at": st,
})
self.history.append(("assistant", self.get_component_obj(self.path[-1]).output()))
+elif "Task has been canceled" in self.error:
+yield decorate("workflow_finished",
+{
+"inputs": kwargs.get("inputs"),
+"outputs": "Task has been canceled",
+"elapsed_time": time.perf_counter() - st,
+"created_at": st,
+})
def is_reff(self, exp: str) -> bool:
exp = exp.strip("{").strip("}")


@@ -139,6 +139,9 @@ class Agent(LLM, ToolBase):
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 20*60)))
def _invoke(self, **kwargs):
+if self.check_if_canceled("Agent processing"):
+return
if kwargs.get("user_prompt"):
usr_pmt = ""
if kwargs.get("reasoning"):
@@ -152,6 +155,8 @@ class Agent(LLM, ToolBase):
self._param.prompts = [{"role": "user", "content": usr_pmt}]
if not self.tools:
+if self.check_if_canceled("Agent processing"):
+return
return LLM._invoke(self, **kwargs)
prompt, msg, user_defined_prompt = self._prepare_prompt_variables()
@@ -171,6 +176,8 @@ class Agent(LLM, ToolBase):
use_tools = []
ans = ""
for delta_ans, tk in self._react_with_tools_streamly(prompt, msg, use_tools, user_defined_prompt):
+if self.check_if_canceled("Agent processing"):
+return
ans += delta_ans
if ans.find("**ERROR**") >= 0:
@@ -191,12 +198,16 @@ class Agent(LLM, ToolBase):
answer_without_toolcall = ""
use_tools = []
for delta_ans,_ in self._react_with_tools_streamly(prompt, msg, use_tools, user_defined_prompt):
+if self.check_if_canceled("Agent streaming"):
+return
if delta_ans.find("**ERROR**") >= 0:
if self.get_exception_default_value():
self.set_output("content", self.get_exception_default_value())
yield self.get_exception_default_value()
else:
self.set_output("_ERROR", delta_ans)
+return
answer_without_toolcall += delta_ans
yield delta_ans
@@ -271,6 +282,8 @@ class Agent(LLM, ToolBase):
st = timer()
txt = ""
for delta_ans in self._gen_citations(entire_txt):
+if self.check_if_canceled("Agent streaming"):
+return
yield delta_ans, 0
txt += delta_ans
@@ -286,6 +299,8 @@ class Agent(LLM, ToolBase):
task_desc = analyze_task(self.chat_mdl, prompt, user_request, tool_metas, user_defined_prompt)
self.callback("analyze_task", {}, task_desc, elapsed_time=timer()-st)
for _ in range(self._param.max_rounds + 1):
+if self.check_if_canceled("Agent streaming"):
+return
response, tk = next_step(self.chat_mdl, hist, tool_metas, task_desc, user_defined_prompt)
# self.callback("next_step", {}, str(response)[:256]+"...")
token_count += tk
@@ -333,6 +348,8 @@ Instructions:
6. Focus on delivering VALUE with the information already gathered
Respond immediately with your final comprehensive answer.
"""
+if self.check_if_canceled("Agent final instruction"):
+return
append_user_content(hist, final_instruction)
for txt, tkcnt in complete():


@@ -417,6 +417,20 @@ class ComponentBase(ABC):
self._param = param
self._param.check()
+def is_canceled(self) -> bool:
+return self._canvas.is_canceled()
+def check_if_canceled(self, message: str = "") -> bool:
+if self.is_canceled():
+task_id = getattr(self._canvas, 'task_id', 'unknown')
+log_message = f"Task {task_id} has been canceled"
+if message:
+log_message += f" during {message}"
+logging.info(log_message)
+self.set_output("_ERROR", "Task has been canceled")
+return True
+return False
def invoke(self, **kwargs) -> dict[str, Any]:
self.set_output("_created_time", time.perf_counter())
try:


@@ -37,7 +37,13 @@ class Begin(UserFillUp):
component_name = "Begin"
def _invoke(self, **kwargs):
+if self.check_if_canceled("Begin processing"):
+return
for k, v in kwargs.get("inputs", {}).items():
+if self.check_if_canceled("Begin processing"):
+return
if isinstance(v, dict) and v.get("type", "").lower().find("file") >=0:
if v.get("optional") and v.get("value", None) is None:
v = None


@@ -98,6 +98,9 @@ class Categorize(LLM, ABC):
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
def _invoke(self, **kwargs):
+if self.check_if_canceled("Categorize processing"):
+return
msg = self._canvas.get_history(self._param.message_history_window_size)
if not msg:
msg = [{"role": "user", "content": ""}]
@@ -114,10 +117,18 @@ class Categorize(LLM, ABC):
---- Real Data ----
{}
""".format(" | ".join(["{}: \"{}\"".format(c["role"].upper(), re.sub(r"\n", "", c["content"], flags=re.DOTALL)) for c in msg]))
+if self.check_if_canceled("Categorize processing"):
+return
ans = chat_mdl.chat(self._param.sys_prompt, [{"role": "user", "content": user_prompt}], self._param.gen_conf())
logging.info(f"input: {user_prompt}, answer: {str(ans)}")
if ERROR_PREFIX in ans:
raise Exception(ans)
+if self.check_if_canceled("Categorize processing"):
+return
# Count the number of times each category appears in the answer.
category_counts = {}
for c in self._param.category_description.keys():


@@ -35,6 +35,9 @@ class UserFillUp(ComponentBase):
component_name = "UserFillUp"
def _invoke(self, **kwargs):
+if self.check_if_canceled("UserFillUp processing"):
+return
if self._param.enable_tips:
content = self._param.tips
for k, v in self.get_input_elements_from_text(self._param.tips).items():
@@ -58,9 +61,9 @@ class UserFillUp(ComponentBase):
self.set_output("tips", content)
for k, v in kwargs.get("inputs", {}).items():
+if self.check_if_canceled("UserFillUp processing"):
+return
self.set_output(k, v)
def thoughts(self) -> str:
return "Waiting for your input..."


@@ -56,6 +56,9 @@ class Invoke(ComponentBase, ABC):
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 3)))
def _invoke(self, **kwargs):
+if self.check_if_canceled("Invoke processing"):
+return
args = {}
for para in self._param.variables:
if para.get("value"):
@@ -89,6 +92,9 @@ class Invoke(ComponentBase, ABC):
last_e = ""
for _ in range(self._param.max_retries + 1):
+if self.check_if_canceled("Invoke processing"):
+return
try:
if method == "get":
response = requests.get(url=url, params=args, headers=headers, proxies=proxies, timeout=self._param.timeout)
@@ -121,6 +127,9 @@ class Invoke(ComponentBase, ABC):
return self.output("result")
except Exception as e:
+if self.check_if_canceled("Invoke processing"):
+return
last_e = e
logging.exception(f"Http request error: {e}")
time.sleep(self._param.delay_after_error)


@@ -56,6 +56,9 @@ class Iteration(ComponentBase, ABC):
return cid
def _invoke(self, **kwargs):
+if self.check_if_canceled("Iteration processing"):
+return
arr = self._canvas.get_variable_value(self._param.items_ref)
if not isinstance(arr, list):
self.set_output("_ERROR", self._param.items_ref + " must be an array, but its type is "+str(type(arr)))


@@ -33,6 +33,9 @@ class IterationItem(ComponentBase, ABC):
self._idx = 0
def _invoke(self, **kwargs):
+if self.check_if_canceled("IterationItem processing"):
+return
parent = self.get_parent()
arr = self._canvas.get_variable_value(parent._param.items_ref)
if not isinstance(arr, list):
@@ -40,12 +43,17 @@ class IterationItem(ComponentBase, ABC):
raise Exception(parent._param.items_ref + " must be an array, but its type is "+str(type(arr)))
if self._idx > 0:
+if self.check_if_canceled("IterationItem processing"):
+return
self.output_collation()
if self._idx >= len(arr):
self._idx = -1
return
+if self.check_if_canceled("IterationItem processing"):
+return
self.set_output("item", arr[self._idx])
self.set_output("index", self._idx)
@@ -80,4 +88,4 @@ class IterationItem(ComponentBase, ABC):
return self._idx == -1
def thoughts(self) -> str:
return "Next turn..."


@@ -207,6 +207,9 @@ class LLM(ComponentBase):
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
def _invoke(self, **kwargs):
+if self.check_if_canceled("LLM processing"):
+return
def clean_formated_answer(ans: str) -> str:
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
ans = re.sub(r"^.*```json", "", ans, flags=re.DOTALL)
@@ -223,6 +226,9 @@ class LLM(ComponentBase):
schema=json.dumps(output_structure, ensure_ascii=False, indent=2)
prompt += structured_output_prompt(schema)
for _ in range(self._param.max_retries+1):
+if self.check_if_canceled("LLM processing"):
+return
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
error = ""
ans = self._generate(msg)
@@ -248,6 +254,9 @@ class LLM(ComponentBase):
return
for _ in range(self._param.max_retries+1):
+if self.check_if_canceled("LLM processing"):
+return
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
error = ""
ans = self._generate(msg)
@@ -269,6 +278,9 @@ class LLM(ComponentBase):
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
answer = ""
for ans in self._generate_streamly(msg):
+if self.check_if_canceled("LLM streaming"):
+return
if ans.find("**ERROR**") >= 0:
if self.get_exception_default_value():
self.set_output("content", self.get_exception_default_value())
@@ -287,4 +299,4 @@ class LLM(ComponentBase):
def thoughts(self) -> str:
_, msg,_ = self._prepare_prompt_variables()
return "⌛Give me a moment—starting from: \n\n" + re.sub(r"(User's query:|[\\]+)", '', msg[-1]['content'], flags=re.DOTALL) + "\n\nI'll figure out our best next move."


@@ -89,6 +89,9 @@ class Message(ComponentBase):
all_content = ""
cache = {}
for r in re.finditer(self.variable_ref_patt, rand_cnt, flags=re.DOTALL):
+if self.check_if_canceled("Message streaming"):
+return
all_content += rand_cnt[s: r.start()]
yield rand_cnt[s: r.start()]
s = r.end()
@@ -104,6 +107,9 @@ class Message(ComponentBase):
if isinstance(v, partial):
cnt = ""
for t in v():
+if self.check_if_canceled("Message streaming"):
+return
all_content += t
cnt += t
yield t
@@ -111,7 +117,7 @@ class Message(ComponentBase):
continue
elif not isinstance(v, str):
try:
-v = json.dumps(v, ensure_ascii=False, indent=2)
+v = json.dumps(v, ensure_ascii=False)
except Exception:
v = str(v)
yield v
@@ -120,6 +126,9 @@ class Message(ComponentBase):
cache[exp] = v
if s < len(rand_cnt):
+if self.check_if_canceled("Message streaming"):
+return
all_content += rand_cnt[s: ]
yield rand_cnt[s: ]
@@ -133,6 +142,9 @@ class Message(ComponentBase):
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
def _invoke(self, **kwargs):
+if self.check_if_canceled("Message processing"):
+return
rand_cnt = random.choice(self._param.content)
if self._param.stream and not self._is_jinjia2(rand_cnt):
self.set_output("content", partial(self._stream, rand_cnt))
@@ -145,6 +157,9 @@ class Message(ComponentBase):
except Exception:
pass
+if self.check_if_canceled("Message processing"):
+return
for n, v in kwargs.items():
content = re.sub(n, v, content)


@@ -63,17 +63,24 @@ class StringTransform(Message, ABC):
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
def _invoke(self, **kwargs):
+if self.check_if_canceled("StringTransform processing"):
+return
if self._param.method == "split":
self._split(kwargs.get("line"))
else:
self._merge(kwargs)
def _split(self, line:str|None = None):
+if self.check_if_canceled("StringTransform split processing"):
+return
var = self._canvas.get_variable_value(self._param.split_ref) if not line else line
if not var:
var = ""
assert isinstance(var, str), "The input variable is not a string: {}".format(type(var))
self.set_input_value(self._param.split_ref, var)
res = []
for i,s in enumerate(re.split(r"(%s)"%("|".join([re.escape(d) for d in self._param.delimiters])), var, flags=re.DOTALL)):
if i % 2 == 1:
@@ -82,6 +89,9 @@ class StringTransform(Message, ABC):
self.set_output("result", res)
def _merge(self, kwargs:dict[str, str] = {}):
+if self.check_if_canceled("StringTransform merge processing"):
+return
script = self._param.script
script, kwargs = self.get_kwargs(script, kwargs, self._param.delimiters[0])


@@ -63,9 +63,18 @@ class Switch(ComponentBase, ABC):
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 3)))
def _invoke(self, **kwargs):
+if self.check_if_canceled("Switch processing"):
+return
for cond in self._param.conditions:
+if self.check_if_canceled("Switch processing"):
+return
res = []
for item in cond["items"]:
+if self.check_if_canceled("Switch processing"):
+return
if not item["cpn_id"]:
continue
cpn_v = self._canvas.get_variable_value(item["cpn_id"])
@@ -128,4 +137,4 @@ class Switch(ComponentBase, ABC):
raise ValueError('Not supported operator' + operator)
def thoughts(self) -> str:
return "I'm weighing a few options and will pick the next step shortly."


@@ -63,12 +63,18 @@ class ArXiv(ToolBase, ABC):
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
def _invoke(self, **kwargs):
+if self.check_if_canceled("ArXiv processing"):
+return
if not kwargs.get("query"):
self.set_output("formalized_content", "")
return ""
last_e = ""
for _ in range(self._param.max_retries+1):
+if self.check_if_canceled("ArXiv processing"):
+return
try:
sort_choices = {"relevance": arxiv.SortCriterion.Relevance,
"lastUpdatedDate": arxiv.SortCriterion.LastUpdatedDate,
@@ -79,12 +85,20 @@ class ArXiv(ToolBase, ABC):
max_results=self._param.top_n,
sort_by=sort_choices[self._param.sort_by]
)
-self._retrieve_chunks(list(arxiv_client.results(search)),
+results = list(arxiv_client.results(search))
+if self.check_if_canceled("ArXiv processing"):
+return
+self._retrieve_chunks(results,
get_title=lambda r: r.title,
get_url=lambda r: r.pdf_url,
get_content=lambda r: r.summary)
return self.output("formalized_content")
except Exception as e:
+if self.check_if_canceled("ArXiv processing"):
+return
last_e = e
logging.exception(f"ArXiv error: {e}")
time.sleep(self._param.delay_after_error)


@@ -125,6 +125,9 @@ class ToolBase(ComponentBase):
return self._param.get_meta()
def invoke(self, **kwargs):
+if self.check_if_canceled("Tool processing"):
+return
self.set_output("_created_time", time.perf_counter())
try:
res = self._invoke(**kwargs)
@@ -170,4 +173,4 @@ class ToolBase(ComponentBase):
self.set_output("formalized_content", "\n".join(kb_prompt({"chunks": chunks, "doc_aggs": aggs}, 200000, True)))
def thoughts(self) -> str:
return self._canvas.get_component_name(self._id) + " is running..."


@@ -131,10 +131,14 @@ class CodeExec(ToolBase, ABC):
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
def _invoke(self, **kwargs):
+if self.check_if_canceled("CodeExec processing"):
+return
lang = kwargs.get("lang", self._param.lang)
script = kwargs.get("script", self._param.script)
arguments = {}
for k, v in self._param.arguments.items():
if kwargs.get(k):
arguments[k] = kwargs[k]
continue
@@ -149,15 +153,28 @@ class CodeExec(ToolBase, ABC):
def _execute_code(self, language: str, code: str, arguments: dict):
import requests
+if self.check_if_canceled("CodeExec execution"):
+return
try:
code_b64 = self._encode_code(code)
code_req = CodeExecutionRequest(code_b64=code_b64, language=language, arguments=arguments).model_dump()
except Exception as e:
+if self.check_if_canceled("CodeExec execution"):
+return
self.set_output("_ERROR", "construct code request error: " + str(e))
try:
+if self.check_if_canceled("CodeExec execution"):
+return "Task has been canceled"
resp = requests.post(url=f"http://{settings.SANDBOX_HOST}:9385/run", json=code_req, timeout=int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
logging.info(f"http://{settings.SANDBOX_HOST}:9385/run, code_req: {code_req}, resp.status_code {resp.status_code}:")
+if self.check_if_canceled("CodeExec execution"):
+return "Task has been canceled"
if resp.status_code != 200:
resp.raise_for_status()
body = resp.json()
@@ -173,16 +190,25 @@ class CodeExec(ToolBase, ABC):
logging.info(f"http://{settings.SANDBOX_HOST}:9385/run -> {rt}")
if isinstance(rt, tuple):
for i, (k, o) in enumerate(self._param.outputs.items()):
+if self.check_if_canceled("CodeExec execution"):
+return
if k.find("_") == 0:
continue
o["value"] = rt[i]
elif isinstance(rt, dict):
for i, (k, o) in enumerate(self._param.outputs.items()):
+if self.check_if_canceled("CodeExec execution"):
+return
if k not in rt or k.find("_") == 0:
continue
o["value"] = rt[k]
else:
for i, (k, o) in enumerate(self._param.outputs.items()):
+if self.check_if_canceled("CodeExec execution"):
+return
if k.find("_") == 0:
continue
o["value"] = rt
@@ -190,6 +216,9 @@ class CodeExec(ToolBase, ABC):
self.set_output("_ERROR", "There is no response from sandbox")
except Exception as e:
+if self.check_if_canceled("CodeExec execution"):
+return
self.set_output("_ERROR", "Exception executing code: " + str(e))
return self.output()


@@ -29,7 +29,7 @@ class CrawlerParam(ToolParamBase):
super().__init__()
self.proxy = None
self.extract_type = "markdown"
def check(self):
self.check_valid_value(self.extract_type, "Type of content from the crawler", ['html', 'markdown', 'content'])
@@ -47,18 +47,24 @@ class Crawler(ToolBase, ABC):
result = asyncio.run(self.get_web(ans))
return Crawler.be_output(result)
except Exception as e:
return Crawler.be_output(f"An unexpected error occurred: {str(e)}")
async def get_web(self, url):
+if self.check_if_canceled("Crawler async operation"):
+return
proxy = self._param.proxy if self._param.proxy else None
async with AsyncWebCrawler(verbose=True, proxy=proxy) as crawler:
result = await crawler.arun(
url=url,
bypass_cache=True
)
+if self.check_if_canceled("Crawler async operation"):
+return
if self._param.extract_type == 'html':
return result.cleaned_html
elif self._param.extract_type == 'markdown':


@@ -46,11 +46,16 @@ class DeepL(ComponentBase, ABC):
component_name = "DeepL"
def _run(self, history, **kwargs):
+if self.check_if_canceled("DeepL processing"):
+return
ans = self.get_input()
ans = " - ".join(ans["content"]) if "content" in ans else ""
if not ans:
return DeepL.be_output("")
+if self.check_if_canceled("DeepL processing"):
+return
try:
translator = deepl.Translator(self._param.auth_key)
result = translator.translate_text(ans, source_lang=self._param.source_lang,
@@ -58,4 +63,6 @@ class DeepL(ComponentBase, ABC):
return DeepL.be_output(result.text)
except Exception as e:
+if self.check_if_canceled("DeepL processing"):
+return
DeepL.be_output("**Error**:" + str(e))


@@ -75,17 +75,30 @@ class DuckDuckGo(ToolBase, ABC):
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
def _invoke(self, **kwargs):
+if self.check_if_canceled("DuckDuckGo processing"):
+return
if not kwargs.get("query"):
self.set_output("formalized_content", "")
return ""
last_e = ""
for _ in range(self._param.max_retries+1):
+if self.check_if_canceled("DuckDuckGo processing"):
+return
try:
if kwargs.get("topic", "general") == "general":
with DDGS() as ddgs:
+if self.check_if_canceled("DuckDuckGo processing"):
+return
# {'title': '', 'href': '', 'body': ''}
duck_res = ddgs.text(kwargs["query"], max_results=self._param.top_n)
+if self.check_if_canceled("DuckDuckGo processing"):
+return
self._retrieve_chunks(duck_res,
get_title=lambda r: r["title"],
get_url=lambda r: r.get("href", r.get("url")),
@@ -94,8 +107,15 @@ class DuckDuckGo(ToolBase, ABC):
return self.output("formalized_content")
else:
with DDGS() as ddgs:
+if self.check_if_canceled("DuckDuckGo processing"):
+return
# {'date': '', 'title': '', 'body': '', 'url': '', 'image': '', 'source': ''}
duck_res = ddgs.news(kwargs["query"], max_results=self._param.top_n)
+if self.check_if_canceled("DuckDuckGo processing"):
+return
self._retrieve_chunks(duck_res,
get_title=lambda r: r["title"],
get_url=lambda r: r.get("href", r.get("url")),
@@ -103,6 +123,9 @@ class DuckDuckGo(ToolBase, ABC):
self.set_output("json", duck_res)
return self.output("formalized_content")
except Exception as e:
+if self.check_if_canceled("DuckDuckGo processing"):
+return
last_e = e
logging.exception(f"DuckDuckGo error: {e}")
time.sleep(self._param.delay_after_error)


@@ -101,19 +101,27 @@ class Email(ToolBase, ABC):
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60)))
def _invoke(self, **kwargs):
+if self.check_if_canceled("Email processing"):
+return
if not kwargs.get("to_email"):
self.set_output("success", False)
return ""
last_e = ""
for _ in range(self._param.max_retries+1):
+if self.check_if_canceled("Email processing"):
+return
try:
# Parse JSON string passed from upstream
email_data = kwargs
# Validate required fields
if "to_email" not in email_data:
-return Email.be_output("Missing required field: to_email")
+self.set_output("_ERROR", "Missing required field: to_email")
+self.set_output("success", False)
+return False
# Create email object
msg = MIMEMultipart('alternative')
@@ -133,6 +141,9 @@ class Email(ToolBase, ABC):
# Connect to SMTP server and send
logging.info(f"Connecting to SMTP server {self._param.smtp_server}:{self._param.smtp_port}")
+if self.check_if_canceled("Email processing"):
+return
context = smtplib.ssl.create_default_context()
with smtplib.SMTP(self._param.smtp_server, self._param.smtp_port) as server:
server.ehlo()
@@ -149,6 +160,10 @@ class Email(ToolBase, ABC):
# Send email
logging.info(f"Sending email to recipients: {recipients}")
+if self.check_if_canceled("Email processing"):
+return
try:
server.send_message(msg, self._param.email, recipients)
success = True


@@ -81,6 +81,8 @@ class ExeSQL(ToolBase, ABC):
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60)))
def _invoke(self, **kwargs):
+if self.check_if_canceled("ExeSQL processing"):
+return
def convert_decimals(obj):
from decimal import Decimal
@@ -96,6 +98,9 @@ class ExeSQL(ToolBase, ABC):
if not sql:
raise Exception("SQL for `ExeSQL` MUST not be empty.")
+if self.check_if_canceled("ExeSQL processing"):
+return
vars = self.get_input_elements_from_text(sql)
args = {}
for k, o in vars.items():
@@ -108,6 +113,9 @@ class ExeSQL(ToolBase, ABC):
self.set_input_value(k, args[k])
sql = self.string_format(sql, args)
+if self.check_if_canceled("ExeSQL processing"):
+return
sqls = sql.split(";")
if self._param.db_type in ["mysql", "mariadb"]:
db = pymysql.connect(db=self._param.database, user=self._param.username, host=self._param.host,
@@ -181,6 +189,10 @@ class ExeSQL(ToolBase, ABC):
sql_res = []
formalized_content = []
for single_sql in sqls:
+if self.check_if_canceled("ExeSQL processing"):
+ibm_db.close(conn)
+return
single_sql = single_sql.replace("```", "").strip()
if not single_sql:
continue
@@ -190,6 +202,9 @@ class ExeSQL(ToolBase, ABC):
rows = []
row = ibm_db.fetch_assoc(stmt)
while row and len(rows) < self._param.max_records:
+if self.check_if_canceled("ExeSQL processing"):
+ibm_db.close(conn)
+return
rows.append(row)
row = ibm_db.fetch_assoc(stmt)
@@ -220,6 +235,11 @@ class ExeSQL(ToolBase, ABC):
sql_res = []
formalized_content = []
for single_sql in sqls:
+if self.check_if_canceled("ExeSQL processing"):
+cursor.close()
+db.close()
+return
single_sql = single_sql.replace('```','')
if not single_sql:
continue
@@ -244,6 +264,9 @@ class ExeSQL(ToolBase, ABC):
sql_res.append(convert_decimals(single_res.to_dict(orient='records')))
formalized_content.append(single_res.to_markdown(index=False, floatfmt=".6f"))
+cursor.close()
+db.close()
self.set_output("json", sql_res)
self.set_output("formalized_content", "\n\n".join(formalized_content))
return self.output("formalized_content")


@@ -59,17 +59,27 @@ class GitHub(ToolBase, ABC):
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
def _invoke(self, **kwargs):
+if self.check_if_canceled("GitHub processing"):
+return
if not kwargs.get("query"):
self.set_output("formalized_content", "")
return ""
last_e = ""
for _ in range(self._param.max_retries+1):
+if self.check_if_canceled("GitHub processing"):
+return
try:
url = 'https://api.github.com/search/repositories?q=' + kwargs["query"] + '&sort=stars&order=desc&per_page=' + str(
self._param.top_n)
headers = {"Content-Type": "application/vnd.github+json", "X-GitHub-Api-Version": '2022-11-28'}
response = requests.get(url=url, headers=headers).json()
+if self.check_if_canceled("GitHub processing"):
+return
self._retrieve_chunks(response['items'],
get_title=lambda r: r["name"],
get_url=lambda r: r["html_url"],
@@ -77,6 +87,9 @@ class GitHub(ToolBase, ABC):
self.set_output("json", response['items'])
return self.output("formalized_content")
except Exception as e:
+if self.check_if_canceled("GitHub processing"):
+return
last_e = e
logging.exception(f"GitHub error: {e}")
time.sleep(self._param.delay_after_error)


@@ -118,6 +118,9 @@ class Google(ToolBase, ABC):
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
def _invoke(self, **kwargs):
+if self.check_if_canceled("Google processing"):
+return
if not kwargs.get("q"):
self.set_output("formalized_content", "")
return ""
@@ -132,8 +135,15 @@ class Google(ToolBase, ABC):
}
last_e = ""
for _ in range(self._param.max_retries+1):
+if self.check_if_canceled("Google processing"):
+return
try:
search = GoogleSearch(params).get_dict()
+if self.check_if_canceled("Google processing"):
+return
self._retrieve_chunks(search["organic_results"],
get_title=lambda r: r["title"],
get_url=lambda r: r["link"],
@@ -142,6 +152,9 @@ class Google(ToolBase, ABC):
self.set_output("json", search["organic_results"])
return self.output("formalized_content")
except Exception as e:
+if self.check_if_canceled("Google processing"):
+return
last_e = e
logging.exception(f"Google error: {e}")
time.sleep(self._param.delay_after_error)


@@ -65,15 +65,25 @@ class GoogleScholar(ToolBase, ABC):
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
def _invoke(self, **kwargs):
+if self.check_if_canceled("GoogleScholar processing"):
+return
if not kwargs.get("query"):
self.set_output("formalized_content", "")
return ""
last_e = ""
for _ in range(self._param.max_retries+1):
+if self.check_if_canceled("GoogleScholar processing"):
+return
try:
scholar_client = scholarly.search_pubs(kwargs["query"], patents=self._param.patents, year_low=self._param.year_low,
year_high=self._param.year_high, sort_by=self._param.sort_by)
+if self.check_if_canceled("GoogleScholar processing"):
+return
self._retrieve_chunks(scholar_client,
get_title=lambda r: r['bib']['title'],
get_url=lambda r: r["pub_url"],
@@ -82,6 +92,9 @@ class GoogleScholar(ToolBase, ABC):
self.set_output("json", list(scholar_client))
return self.output("formalized_content")
except Exception as e:
+if self.check_if_canceled("GoogleScholar processing"):
+return
last_e = e
logging.exception(f"GoogleScholar error: {e}")
time.sleep(self._param.delay_after_error)


@@ -50,6 +50,9 @@ class Jin10(ComponentBase, ABC):
component_name = "Jin10"
def _run(self, history, **kwargs):
+if self.check_if_canceled("Jin10 processing"):
+return
ans = self.get_input()
ans = " - ".join(ans["content"]) if "content" in ans else ""
if not ans:
@@ -58,6 +61,9 @@ class Jin10(ComponentBase, ABC):
jin10_res = []
headers = {'secret-key': self._param.secret_key}
try:
+if self.check_if_canceled("Jin10 processing"):
+return
if self._param.type == "flash":
params = {
'category': self._param.flash_type,
@@ -69,6 +75,8 @@ class Jin10(ComponentBase, ABC):
headers=headers, data=json.dumps(params))
response = response.json()
for i in response['data']:
+if self.check_if_canceled("Jin10 processing"):
+return
jin10_res.append({"content": i['data']['content']})
if self._param.type == "calendar":
params = {
@@ -79,6 +87,8 @@ class Jin10(ComponentBase, ABC):
headers=headers, data=json.dumps(params))
response = response.json()
+if self.check_if_canceled("Jin10 processing"):
+return
jin10_res.append({"content": pd.DataFrame(response['data']).to_markdown()})
if self._param.type == "symbols":
params = {
@@ -90,8 +100,12 @@ class Jin10(ComponentBase, ABC):
url='https://open-data-api.jin10.com/data-api/' + self._param.symbols_datatype + '?type=' + self._param.symbols_type,
headers=headers, data=json.dumps(params))
response = response.json()
+if self.check_if_canceled("Jin10 processing"):
+return
if self._param.symbols_datatype == "symbols":
for i in response['data']:
+if self.check_if_canceled("Jin10 processing"):
+return
i['Commodity Code'] = i['c']
i['Stock Exchange'] = i['e']
i['Commodity Name'] = i['n']
@@ -99,6 +113,8 @@ class Jin10(ComponentBase, ABC):
del i['c'], i['e'], i['n'], i['t']
if self._param.symbols_datatype == "quotes":
for i in response['data']:
+if self.check_if_canceled("Jin10 processing"):
+return
i['Selling Price'] = i['a']
i['Buying Price'] = i['b']
i['Commodity Code'] = i['c']
@@ -120,8 +136,12 @@ class Jin10(ComponentBase, ABC):
url='https://open-data-api.jin10.com/data-api/news',
headers=headers, data=json.dumps(params))
response = response.json()
+if self.check_if_canceled("Jin10 processing"):
+return
jin10_res.append({"content": pd.DataFrame(response['data']).to_markdown()})
except Exception as e:
+if self.check_if_canceled("Jin10 processing"):
+return
return Jin10.be_output("**ERROR**: " + str(e))
if not jin10_res:


@@ -71,23 +71,40 @@ class PubMed(ToolBase, ABC):
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
def _invoke(self, **kwargs):
+if self.check_if_canceled("PubMed processing"):
+return
if not kwargs.get("query"):
self.set_output("formalized_content", "")
return ""
last_e = ""
for _ in range(self._param.max_retries+1):
+if self.check_if_canceled("PubMed processing"):
+return
try:
Entrez.email = self._param.email
pubmedids = Entrez.read(Entrez.esearch(db='pubmed', retmax=self._param.top_n, term=kwargs["query"]))['IdList']
+if self.check_if_canceled("PubMed processing"):
+return
pubmedcnt = ET.fromstring(re.sub(r'<(/?)b>|<(/?)i>', '', Entrez.efetch(db='pubmed', id=",".join(pubmedids),
retmode="xml").read().decode("utf-8")))
+if self.check_if_canceled("PubMed processing"):
+return
self._retrieve_chunks(pubmedcnt.findall("PubmedArticle"),
get_title=lambda child: child.find("MedlineCitation").find("Article").find("ArticleTitle").text,
get_url=lambda child: "https://pubmed.ncbi.nlm.nih.gov/" + child.find("MedlineCitation").find("PMID").text,
get_content=lambda child: self._format_pubmed_content(child),)
return self.output("formalized_content")
except Exception as e:
+if self.check_if_canceled("PubMed processing"):
+return
last_e = e
logging.exception(f"PubMed error: {e}")
time.sleep(self._param.delay_after_error)


@@ -58,12 +58,18 @@ class QWeather(ComponentBase, ABC):
component_name = "QWeather"
def _run(self, history, **kwargs):
+if self.check_if_canceled("Qweather processing"):
+return
ans = self.get_input()
ans = "".join(ans["content"]) if "content" in ans else ""
if not ans:
return QWeather.be_output("")
try:
+if self.check_if_canceled("Qweather processing"):
+return
response = requests.get(
url="https://geoapi.qweather.com/v2/city/lookup?location=" + ans + "&key=" + self._param.web_apikey).json()
if response["code"] == "200":
@@ -71,16 +77,23 @@ class QWeather(ComponentBase, ABC):
else:
return QWeather.be_output("**Error**" + self._param.error_code[response["code"]])
+if self.check_if_canceled("Qweather processing"):
+return
base_url = "https://api.qweather.com/v7/" if self._param.user_type == 'paid' else "https://devapi.qweather.com/v7/"
if self._param.type == "weather":
url = base_url + "weather/" + self._param.time_period + "?location=" + location_id + "&key=" + self._param.web_apikey + "&lang=" + self._param.lang
response = requests.get(url=url).json()
+if self.check_if_canceled("Qweather processing"):
+return
if response["code"] == "200":
if self._param.time_period == "now":
return QWeather.be_output(str(response["now"]))
else:
qweather_res = [{"content": str(i) + "\n"} for i in response["daily"]]
+if self.check_if_canceled("Qweather processing"):
+return
if not qweather_res:
return QWeather.be_output("")
@@ -92,6 +105,8 @@ class QWeather(ComponentBase, ABC):
elif self._param.type == "indices":
url = base_url + "indices/1d?type=0&location=" + location_id + "&key=" + self._param.web_apikey + "&lang=" + self._param.lang
response = requests.get(url=url).json()
+if self.check_if_canceled("Qweather processing"):
+return
if response["code"] == "200":
indices_res = response["daily"][0]["date"] + "\n" + "\n".join(
[i["name"] + ": " + i["category"] + ", " + i["text"] for i in response["daily"]])
@@ -103,9 +118,13 @@ class QWeather(ComponentBase, ABC):
elif self._param.type == "airquality":
url = base_url + "air/now?location=" + location_id + "&key=" + self._param.web_apikey + "&lang=" + self._param.lang
response = requests.get(url=url).json()
+if self.check_if_canceled("Qweather processing"):
+return
if response["code"] == "200":
return QWeather.be_output(str(response["now"]))
else:
return QWeather.be_output("**Error**" + self._param.error_code[response["code"]])
except Exception as e:
+if self.check_if_canceled("Qweather processing"):
+return
return QWeather.be_output("**Error**" + str(e))


@ -82,8 +82,12 @@ class Retrieval(ToolBase, ABC):
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
def _invoke(self, **kwargs):
if self.check_if_canceled("Retrieval processing"):
return
if not kwargs.get("query"):
self.set_output("formalized_content", self._param.empty_response)
return
kb_ids: list[str] = []
for id in self._param.kb_ids:
@ -122,7 +126,7 @@ class Retrieval(ToolBase, ABC):
vars = self.get_input_elements_from_text(kwargs["query"])
vars = {k:o["value"] for k,o in vars.items()}
query = self.string_format(kwargs["query"], vars)
doc_ids=[]
if self._param.meta_data_filter!={}:
metas = DocumentService.get_meta_by_kbs(kb_ids)
@ -184,9 +188,14 @@ class Retrieval(ToolBase, ABC):
rerank_mdl=rerank_mdl,
rank_feature=label_question(query, kbs),
)
if self.check_if_canceled("Retrieval processing"):
return
if self._param.toc_enhance:
chat_mdl = LLMBundle(self._canvas._tenant_id, LLMType.CHAT)
cks = settings.retriever.retrieval_by_toc(query, kbinfos["chunks"], [kb.tenant_id for kb in kbs], chat_mdl, self._param.top_n)
if self.check_if_canceled("Retrieval processing"):
return
if cks:
kbinfos["chunks"] = cks
if self._param.use_kg:
@ -195,6 +204,8 @@ class Retrieval(ToolBase, ABC):
kb_ids,
embd_mdl,
LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT))
if self.check_if_canceled("Retrieval processing"):
return
if ck["content_with_weight"]:
kbinfos["chunks"].insert(0, ck)
else:
@ -202,6 +213,8 @@ class Retrieval(ToolBase, ABC):
if self._param.use_kg and kbs:
ck = settings.kg_retriever.retrieval(query, [kb.tenant_id for kb in kbs], filtered_kb_ids, embd_mdl, LLMBundle(kbs[0].tenant_id, LLMType.CHAT))
if self.check_if_canceled("Retrieval processing"):
return
if ck["content_with_weight"]:
ck["content"] = ck["content_with_weight"]
del ck["content_with_weight"]
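Note the placement in the Retrieval hunks: the flag is re-checked between stages (after the main retrieval call, after TOC enhancement, after each knowledge-graph lookup), so a cancellation takes effect at the next stage boundary instead of only after the whole chain. A generic sketch of that staging pattern, with hypothetical names:

```python
# Illustrative sketch only; run_stages and is_canceled are hypothetical names,
# not the project's API. Each stage is a zero-argument callable run in order.
def run_stages(stages, is_canceled):
    results = []
    for stage in stages:
        if is_canceled():
            return None          # stop at the stage boundary; caller treats None as "canceled"
        results.append(stage())  # e.g. dense retrieval, TOC enhancement, KG lookup
    return results
```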

View File

@ -79,6 +79,9 @@ class SearXNG(ToolBase, ABC):
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
def _invoke(self, **kwargs):
if self.check_if_canceled("SearXNG processing"):
return
# Gracefully handle try-run without inputs
query = kwargs.get("query")
if not query or not isinstance(query, str) or not query.strip():
@ -93,6 +96,9 @@ class SearXNG(ToolBase, ABC):
last_e = ""
for _ in range(self._param.max_retries+1):
if self.check_if_canceled("SearXNG processing"):
return
try:
search_params = {
'q': query,
@ -110,6 +116,9 @@ class SearXNG(ToolBase, ABC):
)
response.raise_for_status()
if self.check_if_canceled("SearXNG processing"):
return
data = response.json()
if not data or not isinstance(data, dict):
@ -121,6 +130,9 @@ class SearXNG(ToolBase, ABC):
results = results[:self._param.top_n]
if self.check_if_canceled("SearXNG processing"):
return
self._retrieve_chunks(results,
get_title=lambda r: r.get("title", ""),
get_url=lambda r: r.get("url", ""),
@ -130,10 +142,16 @@ class SearXNG(ToolBase, ABC):
return self.output("formalized_content")
except requests.RequestException as e:
if self.check_if_canceled("SearXNG processing"):
return
last_e = f"Network error: {e}"
logging.exception(f"SearXNG network error: {e}")
time.sleep(self._param.delay_after_error)
except Exception as e:
if self.check_if_canceled("SearXNG processing"):
return
last_e = str(e)
logging.exception(f"SearXNG error: {e}")
time.sleep(self._param.delay_after_error)
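The web-search tools (SearXNG above, and Tavily, WenCai, Wikipedia and YahooFinance below) all follow the same shape: check once on entry, again before each retry attempt, and again right after each blocking call or inside the exception handler, so a canceled task never sits through the remaining retries and back-off sleeps. A condensed sketch of that shape with illustrative names (not the project's API):

```python
import logging
import time

def run_with_retries(fetch, is_canceled, max_retries: int, delay: float):
    """Generic sketch of the cancellation-aware retry loop used in these hunks.
    fetch and is_canceled are hypothetical callables supplied by the caller."""
    last_err = None
    for _ in range(max_retries + 1):
        if is_canceled():                 # canceled while waiting to retry
            return None
        try:
            result = fetch()              # blocking network call
            if is_canceled():             # canceled while the call was in flight
                return None
            return result
        except Exception as e:
            if is_canceled():             # a cancel may also surface as an exception
                return None
            last_err = e
            logging.exception(e)
            time.sleep(delay)
    raise RuntimeError(f"all retries failed: {last_err}")

# Example use with throwaway callables:
# run_with_retries(lambda: 42, lambda: False, max_retries=2, delay=0.1)  # -> 42
```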

View File

@ -103,6 +103,9 @@ class TavilySearch(ToolBase, ABC):
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
def _invoke(self, **kwargs):
if self.check_if_canceled("TavilySearch processing"):
return
if not kwargs.get("query"):
self.set_output("formalized_content", "")
return ""
@ -113,10 +116,16 @@ class TavilySearch(ToolBase, ABC):
if fld not in kwargs:
kwargs[fld] = getattr(self._param, fld)
for _ in range(self._param.max_retries+1):
if self.check_if_canceled("TavilySearch processing"):
return
try:
kwargs["include_images"] = False
kwargs["include_raw_content"] = False
res = self.tavily_client.search(**kwargs)
if self.check_if_canceled("TavilySearch processing"):
return
self._retrieve_chunks(res["results"],
get_title=lambda r: r["title"],
get_url=lambda r: r["url"],
@ -125,6 +134,9 @@ class TavilySearch(ToolBase, ABC):
self.set_output("json", res["results"])
return self.output("formalized_content")
except Exception as e:
if self.check_if_canceled("TavilySearch processing"):
return
last_e = e
logging.exception(f"Tavily error: {e}")
time.sleep(self._param.delay_after_error)
@ -201,6 +213,9 @@ class TavilyExtract(ToolBase, ABC):
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
def _invoke(self, **kwargs):
if self.check_if_canceled("TavilyExtract processing"):
return
self.tavily_client = TavilyClient(api_key=self._param.api_key)
last_e = None
for fld in ["urls", "extract_depth", "format"]:
@ -209,12 +224,21 @@ class TavilyExtract(ToolBase, ABC):
if kwargs.get("urls") and isinstance(kwargs["urls"], str):
kwargs["urls"] = kwargs["urls"].split(",")
for _ in range(self._param.max_retries+1):
if self.check_if_canceled("TavilyExtract processing"):
return
try:
kwargs["include_images"] = False
res = self.tavily_client.extract(**kwargs)
if self.check_if_canceled("TavilyExtract processing"):
return
self.set_output("json", res["results"])
return self.output("json")
except Exception as e:
if self.check_if_canceled("TavilyExtract processing"):
return
last_e = e
logging.exception(f"Tavily error: {e}")
if last_e:

View File

@ -43,12 +43,18 @@ class TuShare(ComponentBase, ABC):
component_name = "TuShare"
def _run(self, history, **kwargs):
if self.check_if_canceled("TuShare processing"):
return
ans = self.get_input()
ans = ",".join(ans["content"]) if "content" in ans else ""
if not ans:
return TuShare.be_output("")
try:
if self.check_if_canceled("TuShare processing"):
return
tus_res = []
params = {
"api_name": "news",
@ -58,12 +64,18 @@ class TuShare(ComponentBase, ABC):
}
response = requests.post(url="http://api.tushare.pro", data=json.dumps(params).encode('utf-8'))
response = response.json()
if self.check_if_canceled("TuShare processing"):
return
if response['code'] != 0:
return TuShare.be_output(response['msg'])
df = pd.DataFrame(response['data']['items'])
df.columns = response['data']['fields']
if self.check_if_canceled("TuShare processing"):
return
tus_res.append({"content": (df[df['content'].str.contains(self._param.keyword, case=False)]).to_markdown()})
except Exception as e:
if self.check_if_canceled("TuShare processing"):
return
return TuShare.be_output("**ERROR**: " + str(e))
if not tus_res:

View File

@ -70,19 +70,31 @@ class WenCai(ToolBase, ABC):
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
def _invoke(self, **kwargs):
if self.check_if_canceled("WenCai processing"):
return
if not kwargs.get("query"):
self.set_output("report", "")
return ""
last_e = ""
for _ in range(self._param.max_retries+1):
if self.check_if_canceled("WenCai processing"):
return
try:
wencai_res = []
res = pywencai.get(query=kwargs["query"], query_type=self._param.query_type, perpage=self._param.top_n)
if self.check_if_canceled("WenCai processing"):
return
if isinstance(res, pd.DataFrame):
wencai_res.append(res.to_markdown())
elif isinstance(res, dict):
for item in res.items():
if self.check_if_canceled("WenCai processing"):
return
if isinstance(item[1], list):
wencai_res.append(item[0] + "\n" + pd.DataFrame(item[1]).to_markdown())
elif isinstance(item[1], str):
@ -100,6 +112,9 @@ class WenCai(ToolBase, ABC):
self.set_output("report", "\n\n".join(wencai_res))
return self.output("report")
except Exception as e:
if self.check_if_canceled("WenCai processing"):
return
last_e = e
logging.exception(f"WenCai error: {e}")
time.sleep(self._param.delay_after_error)

View File

@ -66,17 +66,26 @@ class Wikipedia(ToolBase, ABC):
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60)))
def _invoke(self, **kwargs):
if self.check_if_canceled("Wikipedia processing"):
return
if not kwargs.get("query"):
self.set_output("formalized_content", "")
return ""
last_e = ""
for _ in range(self._param.max_retries+1):
if self.check_if_canceled("Wikipedia processing"):
return
try:
wikipedia.set_lang(self._param.language)
wiki_engine = wikipedia
pages = []
for p in wiki_engine.search(kwargs["query"], results=self._param.top_n):
if self.check_if_canceled("Wikipedia processing"):
return
try:
pages.append(wikipedia.page(p))
except Exception:
@ -87,6 +96,9 @@ class Wikipedia(ToolBase, ABC):
get_content=lambda r: r.summary)
return self.output("formalized_content")
except Exception as e:
if self.check_if_canceled("Wikipedia processing"):
return
last_e = e
logging.exception(f"Wikipedia error: {e}")
time.sleep(self._param.delay_after_error)

View File

@ -74,15 +74,24 @@ class YahooFinance(ToolBase, ABC):
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60)))
def _invoke(self, **kwargs):
if self.check_if_canceled("YahooFinance processing"):
return
if not kwargs.get("stock_code"):
self.set_output("report", "")
return ""
last_e = ""
for _ in range(self._param.max_retries+1):
if self.check_if_canceled("YahooFinance processing"):
return
yohoo_res = []
try:
msft = yf.Ticker(kwargs["stock_code"])
if self.check_if_canceled("YahooFinance processing"):
return
if self._param.info:
yohoo_res.append("# Information:\n" + pd.Series(msft.info).to_markdown() + "\n")
if self._param.history:
@ -100,6 +109,9 @@ class YahooFinance(ToolBase, ABC):
self.set_output("report", "\n\n".join(yohoo_res))
return self.output("report")
except Exception as e:
if self.check_if_canceled("YahooFinance processing"):
return
last_e = e
logging.exception(f"YahooFinance error: {e}")
time.sleep(self._param.delay_after_error)

View File

@ -156,7 +156,7 @@ def run():
return get_json_result(data={"message_id": task_id})
try:
canvas = Canvas(cvs.dsl, current_user.id, req["id"])
canvas = Canvas(cvs.dsl, current_user.id)
except Exception as e:
return server_error_response(e)
@ -168,8 +168,10 @@ def run():
cvs.dsl = json.loads(str(canvas))
UserCanvasService.update_by_id(req["id"], cvs.to_dict())
except Exception as e:
logging.exception(e)
canvas.cancel_task()
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": False}, ensure_ascii=False) + "\n\n"
resp = Response(sse(), mimetype="text/event-stream")
@ -177,6 +179,7 @@ def run():
resp.headers.add_header("Connection", "keep-alive")
resp.headers.add_header("X-Accel-Buffering", "no")
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
resp.call_on_close(lambda: canvas.cancel_task())
return resp
@ -430,7 +433,7 @@ def test_db_connect():
catalog, schema = _parse_catalog_schema(req["database"])
if not catalog:
return server_error_response("For Trino, 'database' must be 'catalog.schema' or at least 'catalog'.")
http_scheme = "https" if os.environ.get("TRINO_USE_TLS", "0") == "1" else "http"
auth = None
@ -603,4 +606,4 @@ def download():
id = request.args.get("id")
created_by = request.args.get("created_by")
blob = FileService.get_blob(created_by, id)
return flask.make_response(blob)
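In the `run()` endpoint, the cancel flag is raised both when the SSE generator fails and, via `resp.call_on_close(...)`, when the response is closed, which also covers a client that drops the stream mid-run. Below is a minimal, self-contained Flask sketch of that wiring; `DummyCanvas` is a stand-in for illustration, not the project's `Canvas`.

```python
from flask import Flask, Response

app = Flask(__name__)

class DummyCanvas:
    """Hypothetical stand-in; only the cancel hook matters for this sketch."""
    def run(self):
        yield from ("step-1", "step-2")
    def cancel_task(self):
        print("cancel flag set")  # the real one records a flag that components poll

@app.route("/run")
def run_stream():
    canvas = DummyCanvas()

    def sse():
        try:
            for chunk in canvas.run():
                yield f"data:{chunk}\n\n"
        except Exception:
            canvas.cancel_task()  # mark the task canceled if streaming fails
            raise

    resp = Response(sse(), mimetype="text/event-stream")
    # call_on_close fires when the response is closed, on normal completion
    # and when the client disconnects, so a dropped SSE consumer also cancels.
    resp.call_on_close(canvas.cancel_task)
    return resp
```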

View File

@ -24,7 +24,6 @@ import time
import json_repair
from api.db.services.canvas_service import UserCanvasService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
from common.connection_utils import timeout
@ -33,7 +32,6 @@ from common.log_utils import init_root_logger
from common.config_utils import show_configs
from graphrag.general.index import run_graphrag_for_kb
from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
from rag.flow.pipeline import Pipeline
from rag.prompts.generator import keyword_extraction, question_proposal, content_tagging, run_toc_from_text
import logging
import os
@ -478,6 +476,9 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
async def run_dataflow(task: dict):
from api.db.services.canvas_service import UserCanvasService
from rag.flow.pipeline import Pipeline
task_start_ts = timer()
dataflow_id = task["dataflow_id"]
doc_id = task["doc_id"]
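The task-executor hunks move `UserCanvasService` and `Pipeline` from module-level imports into `run_dataflow()`. Deferring an import into a function body is the usual way to break an import cycle between modules; the commit does not state the motive, so that reading is an assumption. A self-contained sketch of the pattern with fabricated module names:

```python
import sys
import types

# Demo only: fabricate a "module_b" at runtime so the deferred import below
# resolves; in a real project module_a and module_b would be separate files
# that import from each other.
module_b = types.ModuleType("module_b")
module_b.helper = lambda: "helped"
sys.modules["module_b"] = module_b

def do_work():
    # Deferred import: resolved on first call rather than at module load time,
    # so loading this module never triggers the circular import.
    from module_b import helper
    return helper()

print(do_work())  # -> helped
```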
@ -944,6 +945,7 @@ async def do_handle_task(task):
async def handle_task():
global DONE_TASKS, FAILED_TASKS
redis_msg, task = await collect()
if not task: